diff --git a/include/mg_procedure.h b/include/mg_procedure.h index 3e05955e1..ce20d9b8d 100644 --- a/include/mg_procedure.h +++ b/include/mg_procedure.h @@ -1405,36 +1405,68 @@ int mgp_must_abort(struct mgp_graph *graph); /// @} -/// @name Kafka message API -/// Currently the API below is for kafka only but in the future -/// mgp_message and mgp_messages might be generic to support -/// other streaming systems. +/// @name Stream Source message API +/// API for accessing specific data contained in a mgp_message +/// used for defining transformation procedures. +/// Not all methods are available for all stream sources +/// so make sure that your transformation procedure can be used +/// for a specific source, i.e. only valid methods are used. ///@{ -/// A single Kafka message +/// A single Stream source message struct mgp_message; -/// A list of Kafka messages +/// A list of Stream source messages struct mgp_messages; +/// Stream source type. +enum mgp_source_type { + KAFKA, + PULSAR, +}; + +/// Get the type of the stream source that produced the message. +enum mgp_error mgp_message_source_type(struct mgp_message *message, enum mgp_source_type *result); + /// Payload is not null terminated and not a string but rather a byte array. /// You need to call mgp_message_payload_size() first, to read the size of /// the payload. +/// Supported stream sources: +/// - Kafka +/// - Pulsar +/// Return MGP_ERROR_INVALID_ARGUMENT if the message is from an unsupported stream source. enum mgp_error mgp_message_payload(struct mgp_message *message, const char **result); /// Get the payload size +/// Supported stream sources: +/// - Kafka +/// - Pulsar +/// Return MGP_ERROR_INVALID_ARGUMENT if the message is from an unsupported stream source. 
enum mgp_error mgp_message_payload_size(struct mgp_message *message, size_t *result); /// Get the name of topic +/// Supported stream sources: +/// - Kafka +/// - Pulsar +/// Return MGP_ERROR_INVALID_ARGUMENT if the message is from an unsupported stream source. enum mgp_error mgp_message_topic_name(struct mgp_message *message, const char **result); /// Get the key of mgp_message as a byte array +/// Supported stream sources: +/// - Kafka +/// Return MGP_ERROR_INVALID_ARGUMENT if the message is from an unsupported stream source. enum mgp_error mgp_message_key(struct mgp_message *message, const char **result); /// Get the key size of mgp_message +/// Supported stream sources: +/// - Kafka +/// Return MGP_ERROR_INVALID_ARGUMENT if the message is from an unsupported stream source. enum mgp_error mgp_message_key_size(struct mgp_message *message, size_t *result); /// Get the timestamp of mgp_message as a byte array +/// Supported stream sources: +/// - Kafka +/// Return MGP_ERROR_INVALID_ARGUMENT if the message is from an unsupported stream source. enum mgp_error mgp_message_timestamp(struct mgp_message *message, int64_t *result); /// Get the message offset from a message. diff --git a/include/mgp.py b/include/mgp.py index 745399f9b..81913ac39 100644 --- a/include/mgp.py +++ b/include/mgp.py @@ -1260,6 +1260,9 @@ class InvalidMessageError(Exception): pass +SOURCE_TYPE_KAFKA = _mgp.SOURCE_TYPE_KAFKA +SOURCE_TYPE_PULSAR = _mgp.SOURCE_TYPE_PULSAR + class Message: """Represents a message from a stream.""" __slots__ = ('_message',) @@ -1280,22 +1283,58 @@ class Message: """Return True if `self` is in valid context and may be used.""" return self._message.is_valid() + def source_type(self) -> str: + """ + Supported in all stream sources + + Raise InvalidArgumentError if the message is from an unsupported stream source. 
+ """ + if not self.is_valid(): + raise InvalidMessageError() + return self._message.source_type() + def payload(self) -> bytes: + """ + Supported stream sources: + - Kafka + - Pulsar + + Raise InvalidArgumentError if the message is from an unsupported stream source. + """ if not self.is_valid(): raise InvalidMessageError() return self._message.payload() def topic_name(self) -> str: + """ + Supported stream sources: + - Kafka + - Pulsar + + Raise InvalidArgumentError if the message is from an unsupported stream source. + """ if not self.is_valid(): raise InvalidMessageError() return self._message.topic_name() def key(self) -> bytes: + """ + Supported stream sources: + - Kafka + + Raise InvalidArgumentError if the message is from an unsupported stream source. + """ if not self.is_valid(): raise InvalidMessageError() return self._message.key() def timestamp(self) -> int: + """ + Supported stream sources: + - Kafka + + Raise InvalidArgumentError if the message is from an unsupported stream source. 
+ """ if not self.is_valid(): raise InvalidMessageError() return self._message.timestamp() diff --git a/libs/setup.sh b/libs/setup.sh index 0099d7c37..df0a1235e 100755 --- a/libs/setup.sh +++ b/libs/setup.sh @@ -125,8 +125,8 @@ declare -A primary_urls=( ["neo4j"]="http://$local_cache_host/file/neo4j-community-3.2.3-unix.tar.gz" ["librdkafka"]="http://$local_cache_host/git/librdkafka.git" ["protobuf"]="http://$local_cache_host/git/protobuf.git" - ["boost"]="https://boostorg.jfrog.io/artifactory/main/release/1.77.0/source/boost_1_77_0.tar.gz" - ["pulsar"]="https://github.com/apache/pulsar.git" + ["boost"]="http://$local_cache_host/file/boost_1_77_0.tar.gz" + ["pulsar"]="http://$local_cache_host/git/pulsar.git" ) # The goal of secondary urls is to have links to the "source of truth" of @@ -271,7 +271,7 @@ repo_clone_try_double "${primary_urls[librdkafka]}" "${secondary_urls[librdkafka protobuf_tag="v3.12.4" repo_clone_try_double "${primary_urls[protobuf]}" "${secondary_urls[protobuf]}" "protobuf" "$protobuf_tag" true pushd protobuf -./autogen.sh && ./configure --prefix=$(pwd)/lib +./autogen.sh && ./configure CC=clang CXX=clang++ --prefix=$(pwd)/lib popd # boost @@ -279,8 +279,8 @@ file_get_try_double "${primary_urls[boost]}" "${secondary_urls[boost]}" tar -xzf boost_1_77_0.tar.gz mv boost_1_77_0 boost pushd boost -./bootstrap.sh --prefix=$(pwd)/lib --with-libraries="system,regex" -./b2 -j$(nproc) install variant=release +./bootstrap.sh --prefix=$(pwd)/lib --with-libraries="system,regex" --with-toolset=clang +./b2 toolset=clang -j$(nproc) install variant=release popd #pulsar diff --git a/release/package/run.sh b/release/package/run.sh index af6ef5154..0809c1517 100755 --- a/release/package/run.sh +++ b/release/package/run.sh @@ -72,7 +72,7 @@ make_package () { docker exec "$build_container" bash -c "/memgraph/environment/os/$os.sh install MEMGRAPH_BUILD_DEPS" echo "Building targeted package..." 
- docker exec "$build_container" bash -c "cd /memgraph && ./init" + docker exec "$build_container" bash -c "cd /memgraph && $ACTIVATE_TOOLCHAIN && ./init" docker exec "$build_container" bash -c "cd $container_build_dir && rm -rf ./*" docker exec "$build_container" bash -c "cd $container_build_dir && $ACTIVATE_TOOLCHAIN && cmake -DCMAKE_BUILD_TYPE=release $telemetry_id_override_flag .." # ' is used instead of " because we need to run make within the allowed diff --git a/src/query/procedure/mg_procedure_impl.cpp b/src/query/procedure/mg_procedure_impl.cpp index 1130f4239..0ef3002c3 100644 --- a/src/query/procedure/mg_procedure_impl.cpp +++ b/src/query/procedure/mg_procedure_impl.cpp @@ -26,6 +26,7 @@ #include "module.hpp" #include "query/procedure/cypher_types.hpp" #include "query/procedure/mg_procedure_helpers.hpp" +#include "query/stream/common.hpp" #include "storage/v2/property_value.hpp" #include "storage/v2/view.hpp" #include "utils/algorithm.hpp" @@ -2493,14 +2494,53 @@ bool IsValidIdentifierName(const char *name) { } // namespace query::procedure +namespace { +class InvalidMessageFunction : public std::invalid_argument { + public: + InvalidMessageFunction(const query::StreamSourceType type, const std::string_view function_name) + : std::invalid_argument{fmt::format("'{}' is not defined for a message from a stream of type '{}'", function_name, + query::StreamSourceTypeToString(type))} {} +}; + +query::StreamSourceType MessageToStreamSourceType(const mgp_message::KafkaMessage & /*msg*/) { + return query::StreamSourceType::KAFKA; +} + +query::StreamSourceType MessageToStreamSourceType(const mgp_message::PulsarMessage & /*msg*/) { + return query::StreamSourceType::PULSAR; +} + +mgp_source_type StreamSourceTypeToMgpSourceType(const query::StreamSourceType type) { + switch (type) { + case query::StreamSourceType::KAFKA: + return mgp_source_type::KAFKA; + case query::StreamSourceType::PULSAR: + return mgp_source_type::PULSAR; + } +} + +} // namespace + +mgp_error 
mgp_message_source_type(mgp_message *message, mgp_source_type *result) { + return WrapExceptions( + [message] { + return std::visit(utils::Overloaded{[](const auto &message) { + return StreamSourceTypeToMgpSourceType(MessageToStreamSourceType(message)); + }}, + message->msg); + }, + result); +} + mgp_error mgp_message_payload(mgp_message *message, const char **result) { return WrapExceptions( [message] { - return std::visit( - utils::Overloaded{[](const mgp_message::KafkaMessage &msg) { return msg->Payload().data(); }, - [](const mgp_message::PulsarMessage &msg) { return msg.Payload().data(); }, - [](const auto & /*other*/) { throw std::invalid_argument("Invalid source type"); }}, - message->msg); + return std::visit(utils::Overloaded{[](const mgp_message::KafkaMessage &msg) { return msg->Payload().data(); }, + [](const mgp_message::PulsarMessage &msg) { return msg.Payload().data(); }, + [](const auto &msg) -> const char * { + throw InvalidMessageFunction(MessageToStreamSourceType(msg), "payload"); + }}, + message->msg); }, result); } @@ -2508,11 +2548,13 @@ mgp_error mgp_message_payload(mgp_message *message, const char **result) { mgp_error mgp_message_payload_size(mgp_message *message, size_t *result) { return WrapExceptions( [message] { - return std::visit( - utils::Overloaded{[](const mgp_message::KafkaMessage &msg) { return msg->Payload().size(); }, - [](const mgp_message::PulsarMessage &msg) { return msg.Payload().size(); }, - [](const auto & /*other*/) { throw std::invalid_argument("Invalid source type"); }}, - message->msg); + return std::visit(utils::Overloaded{[](const mgp_message::KafkaMessage &msg) { return msg->Payload().size(); }, + [](const mgp_message::PulsarMessage &msg) { return msg.Payload().size(); }, + [](const auto &msg) -> size_t { + throw InvalidMessageFunction(MessageToStreamSourceType(msg), + "payload_size"); + }}, + message->msg); }, result); } @@ -2523,7 +2565,9 @@ mgp_error mgp_message_topic_name(mgp_message *message, const char 
**result) { return std::visit( utils::Overloaded{[](const mgp_message::KafkaMessage &msg) { return msg->TopicName().data(); }, [](const mgp_message::PulsarMessage &msg) { return msg.TopicName().data(); }, - [](const auto & /*other*/) { throw std::invalid_argument("Invalid source type"); }}, + [](const auto &msg) -> const char * { + throw InvalidMessageFunction(MessageToStreamSourceType(msg), "topic_name"); + }}, message->msg); }, result); @@ -2532,16 +2576,11 @@ mgp_error mgp_message_topic_name(mgp_message *message, const char **result) { mgp_error mgp_message_key(mgp_message *message, const char **result) { return WrapExceptions( [message] { - return std::visit( - [](T &&msg) -> const char * { - using MessageType = std::decay_t; - if constexpr (std::same_as) { - return msg->Key().data(); - } else { - throw std::invalid_argument("Invalid source type"); - } - }, - message->msg); + return std::visit(utils::Overloaded{[](const mgp_message::KafkaMessage &msg) { return msg->Key().data(); }, + [](const auto &msg) -> const char * { + throw InvalidMessageFunction(MessageToStreamSourceType(msg), "key"); + }}, + message->msg); }, result); } @@ -2549,16 +2588,11 @@ mgp_error mgp_message_key(mgp_message *message, const char **result) { mgp_error mgp_message_key_size(mgp_message *message, size_t *result) { return WrapExceptions( [message] { - return std::visit( - [](T &&msg) -> size_t { - using MessageType = std::decay_t; - if constexpr (std::same_as) { - return msg->Key().size(); - } else { - throw std::invalid_argument("Invalid source type"); - } - }, - message->msg); + return std::visit(utils::Overloaded{[](const mgp_message::KafkaMessage &msg) { return msg->Key().size(); }, + [](const auto &msg) -> size_t { + throw InvalidMessageFunction(MessageToStreamSourceType(msg), "key_size"); + }}, + message->msg); }, result); } @@ -2566,16 +2600,11 @@ mgp_error mgp_message_key_size(mgp_message *message, size_t *result) { mgp_error mgp_message_timestamp(mgp_message *message, int64_t 
*result) { return WrapExceptions( [message] { - return std::visit( - [](T &&msg) -> int64_t { - using MessageType = std::decay_t; - if constexpr (std::same_as) { - return msg->Timestamp(); - } else { - throw std::invalid_argument("Invalid source type"); - } - }, - message->msg); + return std::visit(utils::Overloaded{[](const mgp_message::KafkaMessage &msg) { return msg->Timestamp(); }, + [](const auto &msg) -> int64_t { + throw InvalidMessageFunction(MessageToStreamSourceType(msg), "timestamp"); + }}, + message->msg); }, result); } diff --git a/src/query/procedure/py_module.cpp b/src/query/procedure/py_module.cpp index 768098005..b59fbb16b 100644 --- a/src/query/procedure/py_module.cpp +++ b/src/query/procedure/py_module.cpp @@ -19,6 +19,7 @@ #include #include +#include "mg_procedure.h" #include "query/procedure/mg_procedure_helpers.hpp" #include "query/procedure/mg_procedure_impl.hpp" #include "utils/memory.hpp" @@ -555,6 +556,21 @@ PyObject *PyMessageIsValid(PyMessage *self, PyObject *Py_UNUSED(ignored)) { return PyMessagesIsValid(self->messages, nullptr); } +PyObject *PyMessageGetSourceType(PyMessage *self, PyObject *Py_UNUSED(ignored)) { + MG_ASSERT(self->message); + MG_ASSERT(self->memory); + mgp_source_type source_type{mgp_source_type::KAFKA}; + if (RaiseExceptionFromErrorCode(mgp_message_source_type(self->message, &source_type))) { + return nullptr; + } + auto *py_source_type = PyLong_FromLong(static_cast(source_type)); + if (!py_source_type) { + PyErr_SetString(PyExc_RuntimeError, "Unable to get long from source type"); + return nullptr; + } + return py_source_type; +} + PyObject *PyMessageGetPayload(PyMessage *self, PyObject *Py_UNUSED(ignored)) { MG_ASSERT(self->message); size_t payload_size{0}; @@ -582,7 +598,7 @@ PyObject *PyMessageGetTopicName(PyMessage *self, PyObject *Py_UNUSED(ignored)) { } auto *py_topic_name = PyUnicode_FromString(topic_name); if (!py_topic_name) { - PyErr_SetString(PyExc_RuntimeError, "Unable to get raw bytes from payload"); + 
PyErr_SetString(PyExc_RuntimeError, "Unable to get string from topic_name"); return nullptr; } return py_topic_name; @@ -642,6 +658,7 @@ static PyMethodDef PyMessageMethods[] = { {"__reduce__", reinterpret_cast(DisallowPickleAndCopy), METH_NOARGS, "__reduce__ is not supported"}, {"is_valid", reinterpret_cast(PyMessageIsValid), METH_NOARGS, "Return True if messages is in valid context and may be used."}, + {"source_type", reinterpret_cast(PyMessageGetSourceType), METH_NOARGS, "Get stream source type."}, {"payload", reinterpret_cast(PyMessageGetPayload), METH_NOARGS, "Get payload"}, {"topic_name", reinterpret_cast(PyMessageGetTopicName), METH_NOARGS, "Get topic name."}, {"key", reinterpret_cast(PyMessageGetKey), METH_NOARGS, "Get message key."}, @@ -1921,6 +1938,18 @@ struct PyMgpError { const char *docstring; }; +bool AddModuleConstants(PyObject &module) { + // add source type constants + if (PyModule_AddIntConstant(&module, "SOURCE_TYPE_KAFKA", static_cast(mgp_source_type::KAFKA))) { + return false; + } + if (PyModule_AddIntConstant(&module, "SOURCE_TYPE_PULSAR", static_cast(mgp_source_type::PULSAR))) { + return false; + } + + return true; +} + PyObject *PyInitMgpModule() { PyObject *mgp = PyModule_Create(&PyMgpModule); if (!mgp) return nullptr; @@ -1937,6 +1966,9 @@ PyObject *PyInitMgpModule() { } return true; }; + + if (!AddModuleConstants(*mgp)) return nullptr; + if (!register_type(&PyPropertiesIteratorType, "PropertiesIterator")) return nullptr; if (!register_type(&PyVerticesIteratorType, "VerticesIterator")) return nullptr; if (!register_type(&PyEdgesIteratorType, "EdgesIterator")) return nullptr; diff --git a/src/query/stream/common.hpp b/src/query/stream/common.hpp index 2633ed4f9..5ad1280b0 100644 --- a/src/query/stream/common.hpp +++ b/src/query/stream/common.hpp @@ -66,9 +66,9 @@ enum class StreamSourceType : uint8_t { KAFKA, PULSAR }; constexpr std::string_view StreamSourceTypeToString(StreamSourceType type) { switch (type) { case 
StreamSourceType::KAFKA: - return "KAFKA"; + return "kafka"; case StreamSourceType::PULSAR: - return "PULSAR"; + return "pulsar"; } } diff --git a/tests/e2e/streams/common.py b/tests/e2e/streams/common.py index d92dbd0d1..9e29b4e8d 100644 --- a/tests/e2e/streams/common.py +++ b/tests/e2e/streams/common.py @@ -133,7 +133,7 @@ def check_stream_info(cursor, stream_name, expected_stream_info): def kafka_check_vertex_exists_with_topic_and_payload(cursor, topic, payload_bytes): decoded_payload = payload_bytes.decode('utf-8') check_vertex_exists_with_properties( - cursor, {'topic': f'"{topic}"', 'payload': f'"{decoded_payload}"'}) + cursor, {'topic': f'"{topic}"', 'payload': f'"{decoded_payload}"'}) def pulsar_default_namespace_topic(topic): diff --git a/tests/e2e/streams/kafka_streams_tests.py b/tests/e2e/streams/kafka_streams_tests.py index 4dc1b2084..8206222df 100755 --- a/tests/e2e/streams/kafka_streams_tests.py +++ b/tests/e2e/streams/kafka_streams_tests.py @@ -166,12 +166,8 @@ def test_check_stream( assert f"payload: '{message_as_str}'" in test_results[i][common.QUERY] assert test_results[i][common.PARAMS] is None else: - assert test_results[i][common.QUERY] == ( - "CREATE (n:MESSAGE " - "{timestamp: $timestamp, " - "payload: $payload, " - "topic: $topic})" - ) + assert f"payload: $payload" in test_results[i][ + common.QUERY] and f"topic: $topic" in test_results[i][common.QUERY] parameters = test_results[i][common.PARAMS] # this is not a very sofisticated test, but checks if # timestamp has some kind of value @@ -218,7 +214,13 @@ def test_show_streams(kafka_producer, kafka_topics, connection): common.check_stream_info( cursor, "default_values", - ("default_values", "KAFKA", None, None, "kafka_transform.simple", None, False), + ("default_values", + "kafka", + None, + None, + "kafka_transform.simple", + None, + False), ) common.check_stream_info( @@ -226,7 +228,7 @@ def test_show_streams(kafka_producer, kafka_topics, connection): "complex_values", ( "complex_values", 
- "KAFKA", + "kafka", batch_interval, batch_size, "kafka_transform.with_parameters", diff --git a/tests/e2e/streams/pulsar_streams_tests.py b/tests/e2e/streams/pulsar_streams_tests.py index a64dfbcec..68516f12a 100755 --- a/tests/e2e/streams/pulsar_streams_tests.py +++ b/tests/e2e/streams/pulsar_streams_tests.py @@ -27,7 +27,8 @@ def check_vertex_exists_with_topic_and_payload(cursor, topic, payload_byte): decoded_payload = payload_byte.decode('utf-8') common.check_vertex_exists_with_properties( cursor, { - 'topic': f'"{common.pulsar_default_namespace_topic(topic)}"', 'payload': f'"{decoded_payload}"'}) + 'topic': f'"{common.pulsar_default_namespace_topic(topic)}"', + 'payload': f'"{decoded_payload}"'}) @pytest.mark.parametrize("transformation", TRANSFORMATIONS_TO_CHECK) @@ -205,11 +206,8 @@ def test_check_stream( assert f"payload: '{message_as_str}'" in test_results[i][common.QUERY] assert test_results[i][common.PARAMS] is None else: - assert test_results[i][common.QUERY] == ( - "CREATE (n:MESSAGE " - "{payload: $payload, " - "topic: $topic})" - ) + assert f"payload: $payload" in test_results[i][ + common.QUERY] and f"topic: $topic" in test_results[i][common.QUERY] parameters = test_results[i][common.PARAMS] assert parameters["topic"] == common.pulsar_default_namespace_topic( pulsar_topics[0]) @@ -251,7 +249,13 @@ def test_show_streams(pulsar_client, pulsar_topics, connection): common.check_stream_info( cursor, "default_values", - ("default_values", "PULSAR", None, None, "pulsar_transform.simple", None, False), + ("default_values", + "pulsar", + None, + None, + "pulsar_transform.simple", + None, + False), ) common.check_stream_info( @@ -259,7 +263,7 @@ def test_show_streams(pulsar_client, pulsar_topics, connection): "complex_values", ( "complex_values", - "PULSAR", + "pulsar", batch_interval, batch_size, "pulsar_transform.with_parameters", diff --git a/tests/e2e/streams/streams_owner_tests.py b/tests/e2e/streams/streams_owner_tests.py index 850fd34b9..d67927155 
100644 --- a/tests/e2e/streams/streams_owner_tests.py +++ b/tests/e2e/streams/streams_owner_tests.py @@ -77,7 +77,7 @@ def test_owner_is_shown(kafka_topics, connection): f"TOPICS {kafka_topics[0]} " f"TRANSFORM kafka_transform.simple") - common.check_stream_info(userless_cursor, "test", ("test", "KAFKA", None, None, + common.check_stream_info(userless_cursor, "test", ("test", "kafka", None, None, "kafka_transform.simple", stream_user, False)) diff --git a/tests/e2e/streams/transformations/kafka_transform.py b/tests/e2e/streams/transformations/kafka_transform.py index 286e0dad8..52f234bd0 100644 --- a/tests/e2e/streams/transformations/kafka_transform.py +++ b/tests/e2e/streams/transformations/kafka_transform.py @@ -21,19 +21,18 @@ def simple( for i in range(0, messages.total_messages()): message = messages.message_at(i) + assert message.source_type() == mgp.SOURCE_TYPE_KAFKA payload_as_str = message.payload().decode("utf-8") - offset = message.offset() result_queries.append( mgp.Record( - query=( - f"CREATE (n:MESSAGE {{timestamp: '{message.timestamp()}', " - f"payload: '{payload_as_str}', " - f"topic: '{message.topic_name()}', " - f"offset: {offset}}})" - ), - parameters=None, - ) - ) + query=f""" + CREATE (n:MESSAGE {{ + timestamp: '{message.timestamp()}', + payload: '{payload_as_str}', + offset: {message.offset()}, + topic: '{message.topic_name()}' + }})""", + parameters=None)) return result_queries @@ -47,25 +46,22 @@ def with_parameters( for i in range(0, messages.total_messages()): message = messages.message_at(i) + assert message.source_type() == mgp.SOURCE_TYPE_KAFKA payload_as_str = message.payload().decode("utf-8") - offset = message.offset() result_queries.append( mgp.Record( - query=( - "CREATE (n:MESSAGE " - "{timestamp: $timestamp, " - "payload: $payload, " - "topic: $topic, " - "offset: $offset})" - ), + query=""" + CREATE (n:MESSAGE { + timestamp: $timestamp, + payload: $payload, + offset: $offset, + topic: $topic + })""", parameters={ "timestamp": 
message.timestamp(), "payload": payload_as_str, - "topic": message.topic_name(), - "offset": offset, - }, - ) - ) + "offset": message.offset(), + "topic": message.topic_name()})) return result_queries @@ -78,6 +74,7 @@ def query( for i in range(0, messages.total_messages()): message = messages.message_at(i) + assert message.source_type() == mgp.SOURCE_TYPE_KAFKA payload_as_str = message.payload().decode("utf-8") result_queries.append( mgp.Record(query=payload_as_str, parameters=None) diff --git a/tests/e2e/streams/transformations/pulsar_transform.py b/tests/e2e/streams/transformations/pulsar_transform.py index f177db354..5379db52c 100644 --- a/tests/e2e/streams/transformations/pulsar_transform.py +++ b/tests/e2e/streams/transformations/pulsar_transform.py @@ -21,10 +21,16 @@ def simple(context: mgp.TransCtx, for i in range(0, messages.total_messages()): message = messages.message_at(i) + assert message.source_type() == mgp.SOURCE_TYPE_PULSAR payload_as_str = message.payload().decode("utf-8") - result_queries.append(mgp.Record( - query=f"CREATE (n:MESSAGE {{payload: '{payload_as_str}', topic: '{message.topic_name()}'}})", - parameters=None)) + result_queries.append( + mgp.Record( + query=f""" + CREATE (n:MESSAGE {{ + payload: '{payload_as_str}', + topic: '{message.topic_name()}' + }})""", + parameters=None)) return result_queries @@ -38,10 +44,18 @@ def with_parameters(context: mgp.TransCtx, for i in range(0, messages.total_messages()): message = messages.message_at(i) + assert message.source_type() == mgp.SOURCE_TYPE_PULSAR payload_as_str = message.payload().decode("utf-8") - result_queries.append(mgp.Record( - query="CREATE (n:MESSAGE {payload: $payload, topic: $topic})", - parameters={"payload": payload_as_str, "topic": message.topic_name()})) + result_queries.append( + mgp.Record( + query=""" + CREATE (n:MESSAGE { + payload: $payload, + topic: $topic + })""", + parameters={ + "payload": payload_as_str, + "topic": message.topic_name()})) return result_queries @@ 
-53,9 +67,9 @@ def query(messages: mgp.Messages for i in range(0, messages.total_messages()): message = messages.message_at(i) + assert message.source_type() == mgp.SOURCE_TYPE_PULSAR payload_as_str = message.payload().decode("utf-8") result_queries.append(mgp.Record( query=payload_as_str, parameters=None)) return result_queries - diff --git a/tests/unit/mgp_kafka_c_api.cpp b/tests/unit/mgp_kafka_c_api.cpp index 074efdc9c..ee3a1a00b 100644 --- a/tests/unit/mgp_kafka_c_api.cpp +++ b/tests/unit/mgp_kafka_c_api.cpp @@ -20,6 +20,7 @@ #include "gtest/gtest.h" #include "integrations/kafka/consumer.hpp" #include "query/procedure/mg_procedure_impl.hpp" +#include "query/stream/common.hpp" #include "test_utils.hpp" #include "utils/pmr/vector.hpp" @@ -156,6 +157,8 @@ TEST_F(MgpApiTest, TestAllMgpKafkaCApi) { EXPECT_EQ(EXPECT_MGP_NO_ERROR(size_t, mgp_message_key_size, message), 1); EXPECT_EQ(*EXPECT_MGP_NO_ERROR(const char *, mgp_message_key, message), expected[i].key); + // Test for source type + EXPECT_EQ(EXPECT_MGP_NO_ERROR(mgp_source_type, mgp_message_source_type, message), mgp_source_type::KAFKA); // Test for payload size EXPECT_EQ(EXPECT_MGP_NO_ERROR(size_t, mgp_message_payload_size, message), expected[i].payload_size); // Test for payload