Add basic queries for managing streams

* Add CREATE, START, STOP and DROP queries

* Fix definition of port in replica query

* Explicitly stop the consumer before removing

* Fix offset committing in Consumer

* Add tests for basic stream queries

* Remove unnecessary WITH keywords from CREATE query

* Add tests

* Add STREAM privilege

* Disable non-working test

The functionality is tested manually, but I couldn't make it work with
the mock Kafka cluster.

* Add support for multiple topic names

* Replace skiplist with synchronized map

* Make Consumer::Test const and improve error handling

The error-handling improvement mostly concerns the Test function.
Instead of trying to revert the assignments, Test just stores the last
committed assignment. When Start or Test is called, it checks for the
last committed assignment and, if one is saved, tries to restore it.
This way:
1. All failures are reported to the user (failed to save/restore)
2. A failed assignment cannot terminate Memgraph
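
A minimal sketch of that save/restore idea using librdkafka's C++ API (an illustration of the approach described above, with simplified error types; the real logic lives in Consumer::Test and Consumer::StartConsuming in the diff below):

#include <stdexcept>
#include <vector>
#include <librdkafka/rdkafkacpp.h>

// Save the current assignment once, so a later Start/Test can restore it.
void SaveAssignmentIfNeeded(RdKafka::KafkaConsumer &consumer,
                            std::vector<RdKafka::TopicPartition *> &last_assignment) {
  if (!last_assignment.empty()) return;  // already saved by a previous Test
  if (auto err = consumer.assignment(last_assignment); err != RdKafka::ERR_NO_ERROR) {
    // The failure is surfaced to the caller instead of terminating the process.
    throw std::runtime_error(RdKafka::err2str(err));
  }
}

// Restore the saved assignment, then free it: the consumer is in a clean state again.
void RestoreAssignmentIfSaved(RdKafka::KafkaConsumer &consumer,
                              std::vector<RdKafka::TopicPartition *> &last_assignment) {
  if (last_assignment.empty()) return;
  if (auto err = consumer.assign(last_assignment); err != RdKafka::ERR_NO_ERROR) {
    throw std::runtime_error(RdKafka::err2str(err));
  }
  RdKafka::TopicPartition::destroy(last_assignment);  // frees the saved partition objects
}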

* Make Test not block creating/dropping other streams
János Benjamin Antal 2021-06-28 17:21:13 +02:00 committed by Antonio Andelic
parent d80ff745eb
commit ac230d0c2d
22 changed files with 779 additions and 214 deletions

View File

@ -53,6 +53,8 @@ std::string PermissionToString(Permission permission) {
return "CONFIG";
case Permission::AUTH:
return "AUTH";
case Permission::STREAM:
return "STREAM";
}
}

View File

@ -27,7 +27,8 @@ enum class Permission : uint64_t {
FREE_MEMORY = 1U << 13U,
TRIGGER = 1U << 14U,
CONFIG = 1U << 15U,
AUTH = 1U << 16U
AUTH = 1U << 16U,
STREAM = 1U << 17U
};
// clang-format on
@ -37,7 +38,7 @@ const std::vector<Permission> kPermissionsAll = {Permission::MATCH, Permiss
Permission::INDEX, Permission::STATS, Permission::CONSTRAINT,
Permission::DUMP, Permission::AUTH, Permission::REPLICATION,
Permission::DURABILITY, Permission::READ_FILE, Permission::FREE_MEMORY,
Permission::TRIGGER, Permission::CONFIG};
Permission::TRIGGER, Permission::CONFIG, Permission::STREAM};
// Function that converts a permission to its string representation.
std::string PermissionToString(Permission permission);
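
Because every permission occupies a distinct bit of the uint64_t, a set of permissions is just a bitmask; a tiny illustration of how such flags combine (the real permission checks in auth are more involved):

#include <cstdint>

uint64_t granted = static_cast<uint64_t>(Permission::TRIGGER) |
                   static_cast<uint64_t>(Permission::STREAM);
bool may_use_streams = (granted & static_cast<uint64_t>(Permission::STREAM)) != 0;  // true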

View File

@ -38,6 +38,8 @@ auth::Permission PrivilegeToPermission(query::AuthQuery::Privilege privilege) {
return auth::Permission::CONFIG;
case query::AuthQuery::Privilege::AUTH:
return auth::Permission::AUTH;
case query::AuthQuery::Privilege::STREAM:
return auth::Permission::STREAM;
}
}
} // namespace glue

View File

@ -5,6 +5,8 @@
#include <memory>
#include <unordered_set>
#include <librdkafka/rdkafkacpp.h>
#include <spdlog/spdlog.h>
#include "integrations/kafka/exceptions.hpp"
#include "utils/exceptions.hpp"
#include "utils/logging.hpp"
@ -17,6 +19,51 @@ constexpr std::chrono::milliseconds kDefaultBatchInterval{100};
constexpr int64_t kDefaultBatchSize = 1000;
constexpr int64_t kDefaultTestBatchLimit = 1;
namespace {
utils::BasicResult<std::string, std::vector<Message>> GetBatch(RdKafka::KafkaConsumer &consumer,
const ConsumerInfo &info,
std::atomic<bool> &is_running) {
std::vector<Message> batch{};
int64_t batch_size = info.batch_size.value_or(kDefaultBatchSize);
batch.reserve(batch_size);
auto remaining_timeout_in_ms = info.batch_interval.value_or(kDefaultBatchInterval).count();
auto start = std::chrono::steady_clock::now();
bool run_batch = true;
for (int64_t i = 0; remaining_timeout_in_ms > 0 && i < batch_size && is_running.load(); ++i) {
std::unique_ptr<RdKafka::Message> msg(consumer.consume(remaining_timeout_in_ms));
switch (msg->err()) {
case RdKafka::ERR__TIMED_OUT:
run_batch = false;
break;
case RdKafka::ERR_NO_ERROR:
batch.emplace_back(std::move(msg));
break;
default:
auto error = msg->errstr();
spdlog::warn("Unexpected error while consuming message in consumer {}, error: {}!", info.consumer_name,
msg->errstr());
return {std::move(error)};
}
if (!run_batch) {
break;
}
auto now = std::chrono::steady_clock::now();
auto took = std::chrono::duration_cast<std::chrono::milliseconds>(now - start);
remaining_timeout_in_ms = remaining_timeout_in_ms - took.count();
start = now;
}
return {std::move(batch)};
}
} // namespace
Message::Message(std::unique_ptr<RdKafka::Message> &&message) : message_{std::move(message)} {
// Because of these asserts, the message can be safely accessed in the member functions, because it cannot
// be null and always points to a valid message (not to a wrapped error)
@ -62,7 +109,7 @@ Consumer::Consumer(const std::string &bootstrap_servers, ConsumerInfo info, Cons
throw ConsumerFailedToInitializeException(info_.consumer_name, error);
}
if (conf->set("enable.auto.offset.store", "false", error) != RdKafka::Conf::CONF_OK) {
if (conf->set("enable.auto.commit", "false", error) != RdKafka::Conf::CONF_OK) {
throw ConsumerFailedToInitializeException(info_.consumer_name, error);
}
@ -108,7 +155,10 @@ Consumer::Consumer(const std::string &bootstrap_servers, ConsumerInfo info, Cons
}
}
Consumer::~Consumer() { StopIfRunning(); }
Consumer::~Consumer() {
StopIfRunning();
RdKafka::TopicPartition::destroy(last_assignment_);
}
void Consumer::Start() {
if (is_running_) {
@ -141,47 +191,36 @@ void Consumer::StopIfRunning() {
}
}
void Consumer::Test(std::optional<int64_t> limit_batches, const ConsumerFunction &test_consumer_function) {
if (is_running_) {
void Consumer::Test(std::optional<int64_t> limit_batches, const ConsumerFunction &test_consumer_function) const {
// The implementation of this function is questionable: it is const qualified even though it changes the inner state
// of KafkaConsumer. However, it saves the current assignment so that future Test/Start calls can restore it, so the
// changes made by this function shouldn't be visible to the users of the class. It also passes a non-const reference
// to KafkaConsumer to the GetBatch function. That means the object is only bitwise const (KafkaConsumer is stored in
// a unique_ptr) and internally mostly synchronized. Mostly, because Start/Stop require exclusive access to the
// consumer, so we don't have to deal with simultaneous calls to those functions. The only remaining concern in this
// function is to prevent executing it on multiple threads simultaneously.
if (is_running_.exchange(true)) {
throw ConsumerRunningException(info_.consumer_name);
}
utils::OnScopeExit restore_is_running([this] { is_running_.store(false); });
if (last_assignment_.empty()) {
if (auto err = consumer_->assignment(last_assignment_); err != RdKafka::ERR_NO_ERROR) {
spdlog::warn("Saving the commited offset of consumer {} failed: {}", info_.consumer_name, RdKafka::err2str(err));
throw ConsumerTestFailedException(info_.consumer_name,
fmt::format("Couldn't save commited offsets: '{}'", RdKafka::err2str(err)));
}
} else {
if (auto err = consumer_->assign(last_assignment_); err != RdKafka::ERR_NO_ERROR) {
throw ConsumerTestFailedException(info_.consumer_name,
fmt::format("Couldn't restore commited offsets: '{}'", RdKafka::err2str(err)));
}
}
int64_t num_of_batches = limit_batches.value_or(kDefaultTestBatchLimit);
is_running_.store(true);
std::vector<std::unique_ptr<RdKafka::TopicPartition>> partitions;
{
// Save the current offsets in order to restore them in cleanup
std::vector<RdKafka::TopicPartition *> tmp_partitions;
if (const auto err = consumer_->assignment(tmp_partitions); err != RdKafka::ERR_NO_ERROR) {
throw ConsumerTestFailedException(info_.consumer_name, RdKafka::err2str(err));
}
if (const auto err = consumer_->position(tmp_partitions); err != RdKafka::ERR_NO_ERROR) {
throw ConsumerTestFailedException(info_.consumer_name, RdKafka::err2str(err));
}
partitions.reserve(tmp_partitions.size());
std::transform(
tmp_partitions.begin(), tmp_partitions.end(), std::back_inserter(partitions),
[](RdKafka::TopicPartition *const partition) { return std::unique_ptr<RdKafka::TopicPartition>{partition}; });
}
utils::OnScopeExit cleanup([this, &partitions]() {
is_running_.store(false);
std::vector<RdKafka::TopicPartition *> tmp_partitions;
tmp_partitions.reserve(partitions.size());
std::transform(partitions.begin(), partitions.end(), std::back_inserter(tmp_partitions),
[](const auto &partition) { return partition.get(); });
if (const auto err = consumer_->assign(tmp_partitions); err != RdKafka::ERR_NO_ERROR) {
spdlog::error("Couldn't restore previous offsets after testing Kafka consumer {}!", info_.consumer_name);
throw ConsumerTestFailedException(info_.consumer_name, RdKafka::err2str(err));
}
});
for (int64_t i = 0; i < num_of_batches;) {
auto maybe_batch = GetBatch();
auto maybe_batch = GetBatch(*consumer_, info_, is_running_);
if (maybe_batch.HasError()) {
throw ConsumerTestFailedException(info_.consumer_name, maybe_batch.GetError());
@ -218,6 +257,7 @@ void Consumer::event_cb(RdKafka::Event &event) {
break;
}
}
void Consumer::StartConsuming() {
MG_ASSERT(!is_running_, "Cannot start already running consumer!");
@ -229,6 +269,14 @@ void Consumer::StartConsuming() {
is_running_.store(true);
if (!last_assignment_.empty()) {
if (auto err = consumer_->assign(last_assignment_); err != RdKafka::ERR_NO_ERROR) {
throw ConsumerStartFailedException(info_.consumer_name,
fmt::format("Couldn't restore commited offsets: '{}'", RdKafka::err2str(err)));
}
RdKafka::TopicPartition::destroy(last_assignment_);
}
thread_ = std::thread([this] {
constexpr auto kMaxThreadNameSize = utils::GetMaxThreadNameSize();
const auto full_thread_name = "Cons#" + info_.consumer_name;
@ -236,7 +284,7 @@ void Consumer::StartConsuming() {
utils::ThreadSetName(full_thread_name.substr(0, kMaxThreadNameSize));
while (is_running_) {
auto maybe_batch = this->GetBatch();
auto maybe_batch = GetBatch(*consumer_, info_, is_running_);
if (maybe_batch.HasError()) {
spdlog::warn("Error happened in consumer {} while fetching messages: {}!", info_.consumer_name,
maybe_batch.GetError());
@ -251,8 +299,11 @@ void Consumer::StartConsuming() {
// TODO (mferencevic): Figure out what to do with all other exceptions.
try {
consumer_function_(batch);
consumer_->commitSync();
} catch (const utils::BasicException &e) {
if (auto err = consumer_->commitSync(); err != RdKafka::ERR_NO_ERROR) {
spdlog::warn("Committing offset of consumer {} failed: {}", info_.consumer_name, RdKafka::err2str(err));
break;
}
} catch (const std::exception &e) {
spdlog::warn("Error happened in consumer {} while processing a batch: {}!", info_.consumer_name, e.what());
break;
}
@ -266,45 +317,4 @@ void Consumer::StopConsuming() {
if (thread_.joinable()) thread_.join();
}
utils::BasicResult<std::string, std::vector<Message>> Consumer::GetBatch() {
std::vector<Message> batch{};
int64_t batch_size = info_.batch_size.value_or(kDefaultBatchSize);
batch.reserve(batch_size);
auto remaining_timeout_in_ms = info_.batch_interval.value_or(kDefaultBatchInterval).count();
auto start = std::chrono::steady_clock::now();
bool run_batch = true;
for (int64_t i = 0; remaining_timeout_in_ms > 0 && i < batch_size && is_running_.load(); ++i) {
std::unique_ptr<RdKafka::Message> msg(consumer_->consume(remaining_timeout_in_ms));
switch (msg->err()) {
case RdKafka::ERR__TIMED_OUT:
run_batch = false;
break;
case RdKafka::ERR_NO_ERROR:
batch.emplace_back(std::move(msg));
break;
default:
auto error = msg->errstr();
spdlog::warn("Unexpected error while consuming message in consumer {}, error: {}!", info_.consumer_name,
msg->errstr());
return {std::move(error)};
}
if (!run_batch) {
break;
}
auto now = std::chrono::steady_clock::now();
auto took = std::chrono::duration_cast<std::chrono::milliseconds>(now - start);
remaining_timeout_in_ms = remaining_timeout_in_ms - took.count();
start = now;
}
return {std::move(batch)};
}
} // namespace integrations::kafka
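
The is_running_.exchange(true) guard at the top of Test is a small reusable pattern: atomically claim a busy flag and release it on every exit path, including exceptions. A self-contained sketch, with a hand-rolled scope guard standing in for utils::OnScopeExit:

#include <atomic>
#include <stdexcept>
#include <utility>

// Minimal scope guard, standing in for utils::OnScopeExit.
template <typename F>
class OnScopeExit {
 public:
  explicit OnScopeExit(F f) : f_(std::move(f)) {}
  ~OnScopeExit() { f_(); }

 private:
  F f_;
};

std::atomic<bool> is_running{false};

void Test() {
  // exchange(true) returns the previous value: if it was already true,
  // someone else (Start or another Test) owns the consumer right now.
  if (is_running.exchange(true)) throw std::runtime_error("consumer is already running");
  OnScopeExit restore{[] { is_running.store(false); }};  // runs on return and on throw
  // ... consume a few batches ...
}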

View File

@ -82,7 +82,7 @@ class Consumer final : public RdKafka::EventCb {
///
/// @throws ConsumerFailedToInitializeException if the consumer can't connect
/// to the Kafka endpoint.
explicit Consumer(const std::string &bootstrap_servers, ConsumerInfo info, ConsumerFunction consumer_function);
Consumer(const std::string &bootstrap_servers, ConsumerInfo info, ConsumerFunction consumer_function);
~Consumer() override;
Consumer(const Consumer &other) = delete;
@ -121,7 +121,7 @@ class Consumer final : public RdKafka::EventCb {
/// @param test_consumer_function a function to feed the received messages in, only used during this dry-run.
///
/// @throws ConsumerRunningException if the consumer is already running.
void Test(std::optional<int64_t> limit_batches, const ConsumerFunction &test_consumer_function);
void Test(std::optional<int64_t> limit_batches, const ConsumerFunction &test_consumer_function) const;
/// Returns true if the consumer is actively consuming messages.
bool IsRunning() const;
@ -135,11 +135,10 @@ class Consumer final : public RdKafka::EventCb {
void StopConsuming();
utils::BasicResult<std::string, std::vector<Message>> GetBatch();
ConsumerInfo info_;
ConsumerFunction consumer_function_;
mutable std::atomic<bool> is_running_{false};
mutable std::vector<RdKafka::TopicPartition *> last_assignment_; // Protected by is_running_
std::optional<int64_t> limit_batches_{std::nullopt};
std::unique_ptr<RdKafka::KafkaConsumer, std::function<void(RdKafka::KafkaConsumer *)>> consumer_;
std::thread thread_;

View File

@ -32,6 +32,12 @@ class ConsumerTestFailedException : public KafkaStreamException {
: KafkaStreamException("Kafka consumer {} test failed: {}", consumer_name, error) {}
};
class ConsumerStartFailedException : public KafkaStreamException {
public:
explicit ConsumerStartFailedException(const std::string &consumer_name, const std::string &error)
: KafkaStreamException("Starting Kafka consumer {} failed: {}", consumer_name, error) {}
};
class TopicNotFoundException : public KafkaStreamException {
public:
TopicNotFoundException(const std::string &consumer_name, const std::string &topic_name)

View File

@ -141,11 +141,6 @@ class UserModificationInMulticommandTxException : public QueryException {
: QueryException("Authentication clause not allowed in multicommand transactions.") {}
};
class StreamClauseInMulticommandTxException : public QueryException {
public:
StreamClauseInMulticommandTxException() : QueryException("Stream clause not allowed in multicommand transactions.") {}
};
class InvalidArgumentsException : public QueryException {
public:
InvalidArgumentsException(const std::string &argument_name, const std::string &message)
@ -176,6 +171,12 @@ class TriggerModificationInMulticommandTxException : public QueryException {
: QueryException("Trigger queries not allowed in multicommand transactions.") {}
};
class StreamQueryInMulticommandTxException : public QueryException {
public:
StreamQueryInMulticommandTxException()
: QueryException("Stream queries are not allowed in multicommand transactions.") {}
};
class IsolationLevelModificationInMulticommandTxException : public QueryException {
public:
IsolationLevelModificationInMulticommandTxException()

View File

@ -2193,7 +2193,7 @@ cpp<#
(:serialize))
(lcp:define-enum privilege
(create delete match merge set remove index stats auth constraint
dump replication durability read_file free_memory trigger config)
dump replication durability read_file free_memory trigger config stream)
(:serialize))
#>cpp
AuthQuery() = default;
@ -2233,7 +2233,7 @@ const std::vector<AuthQuery::Privilege> kPrivilegesAll = {
AuthQuery::Privilege::READ_FILE,
AuthQuery::Privilege::DURABILITY,
AuthQuery::Privilege::FREE_MEMORY, AuthQuery::Privilege::TRIGGER,
AuthQuery::Privilege::CONFIG};
AuthQuery::Privilege::CONFIG, AuthQuery::Privilege::STREAM};
cpp<#
(lcp:define-class info-query (query)
@ -2310,7 +2310,9 @@ cpp<#
(socket_address "Expression *" :initval "nullptr" :scope :public
:slk-save #'slk-save-ast-pointer
:slk-load (slk-load-ast-pointer "Expression"))
(port "Expression *" :initval "nullptr" :scope :public)
(port "Expression *" :initval "nullptr" :scope :public
:slk-save #'slk-save-ast-pointer
:slk-load (slk-load-ast-pointer "Expression"))
(sync_mode "SyncMode" :scope :public)
(timeout "Expression *" :initval "nullptr" :scope :public
:slk-save #'slk-save-ast-pointer
@ -2464,4 +2466,33 @@ cpp<#
(:serialize (:slk))
(:clone))
(lcp:define-class stream-query (query)
((action "Action" :scope :public)
(stream_name "std::string" :scope :public)
(topic_names "std::vector<std::string>" :scope :public)
(transform_name "std::string" :scope :public)
(consumer_group "std::string" :scope :public)
(batch_interval "Expression *" :initval "nullptr" :scope :public
:slk-save #'slk-save-ast-pointer
:slk-load (slk-load-ast-pointer "Expression"))
(batch_size "Expression *" :initval "nullptr" :scope :public
:slk-save #'slk-save-ast-pointer
:slk-load (slk-load-ast-pointer "Expression")))
(:public
(lcp:define-enum action
(create-stream drop-stream start-stream stop-stream start-all-streams stop-all-streams show-streams test-stream)
(:serialize))
#>cpp
StreamQuery() = default;
DEFVISITABLE(QueryVisitor<void>);
cpp<#)
(:private
#>cpp
friend class AstStorage;
cpp<#)
(:serialize (:slk))
(:clone))
(lcp:pop-namespace) ;; namespace query

View File

@ -79,6 +79,7 @@ class FreeMemoryQuery;
class TriggerQuery;
class IsolationLevelQuery;
class CreateSnapshotQuery;
class StreamQuery;
using TreeCompositeVisitor = ::utils::CompositeVisitor<
SingleQuery, CypherUnion, NamedExpression, OrOperator, XorOperator, AndOperator, NotOperator, AdditionOperator,
@ -110,9 +111,9 @@ class ExpressionVisitor
None, ParameterLookup, Identifier, PrimitiveLiteral, RegexMatch> {};
template <class TResult>
class QueryVisitor : public ::utils::Visitor<TResult, CypherQuery, ExplainQuery, ProfileQuery, IndexQuery, AuthQuery,
InfoQuery, ConstraintQuery, DumpQuery, ReplicationQuery, LockPathQuery,
FreeMemoryQuery, TriggerQuery, IsolationLevelQuery, CreateSnapshotQuery> {
};
class QueryVisitor
: public ::utils::Visitor<TResult, CypherQuery, ExplainQuery, ProfileQuery, IndexQuery, AuthQuery, InfoQuery,
ConstraintQuery, DumpQuery, ReplicationQuery, LockPathQuery, FreeMemoryQuery,
TriggerQuery, IsolationLevelQuery, CreateSnapshotQuery, StreamQuery> {};
} // namespace query

View File

@ -15,6 +15,7 @@
#include <climits>
#include <codecvt>
#include <cstring>
#include <iterator>
#include <limits>
#include <string>
#include <tuple>
@ -53,6 +54,16 @@ std::optional<std::pair<query::Expression *, size_t>> VisitMemoryLimit(
return std::make_pair(memory_limit, memory_scale);
}
std::string JoinSymbolicNames(antlr4::tree::ParseTreeVisitor *visitor,
const std::vector<MemgraphCypher::SymbolicNameContext *> &symbolicNames) {
std::vector<std::string> procedure_subnames;
procedure_subnames.reserve(symbolicNames.size());
for (auto *subname : symbolicNames) {
procedure_subnames.emplace_back(subname->accept(visitor).as<std::string>());
}
return utils::Join(procedure_subnames, ".");
}
} // namespace
antlrcpp::Any CypherMainVisitor::visitExplainQuery(MemgraphCypher::ExplainQueryContext *ctx) {
@ -447,6 +458,82 @@ antlrcpp::Any CypherMainVisitor::visitCreateSnapshotQuery(MemgraphCypher::Create
return query_;
}
antlrcpp::Any CypherMainVisitor::visitStreamQuery(MemgraphCypher::StreamQueryContext *ctx) {
MG_ASSERT(ctx->children.size() == 1, "StreamQuery should have exactly one child!");
auto *stream_query = ctx->children[0]->accept(this).as<StreamQuery *>();
query_ = stream_query;
return stream_query;
}
antlrcpp::Any CypherMainVisitor::visitCreateStream(MemgraphCypher::CreateStreamContext *ctx) {
auto *stream_query = storage_->Create<StreamQuery>();
stream_query->action_ = StreamQuery::Action::CREATE_STREAM;
stream_query->stream_name_ = ctx->streamName()->symbolicName()->accept(this).as<std::string>();
auto *topic_names_ctx = ctx->topicNames();
MG_ASSERT(topic_names_ctx != nullptr);
auto topic_names = topic_names_ctx->symbolicNameWithDots();
MG_ASSERT(!topic_names.empty());
stream_query->topic_names_.reserve(topic_names.size());
std::transform(topic_names.begin(), topic_names.end(), std::back_inserter(stream_query->topic_names_),
[this](auto *topic_name) { return JoinSymbolicNames(this, topic_name->symbolicName()); });
stream_query->transform_name_ = JoinSymbolicNames(this, ctx->transformationName->symbolicName());
if (ctx->CONSUMER_GROUP()) {
stream_query->consumer_group_ = JoinSymbolicNames(this, ctx->consumerGroup->symbolicName());
}
if (ctx->BATCH_INTERVAL()) {
if (!ctx->batchInterval->numberLiteral() || !ctx->batchInterval->numberLiteral()->integerLiteral()) {
throw SemanticException("Batch interval should be an integer literal!");
}
stream_query->batch_interval_ = ctx->batchInterval->accept(this);
}
if (ctx->BATCH_SIZE()) {
if (!ctx->batchSize->numberLiteral() || !ctx->batchSize->numberLiteral()->integerLiteral()) {
throw SemanticException("Batch size should be an integer literal!");
}
stream_query->batch_size_ = ctx->batchSize->accept(this);
}
return stream_query;
}
antlrcpp::Any CypherMainVisitor::visitDropStream(MemgraphCypher::DropStreamContext *ctx) {
auto *stream_query = storage_->Create<StreamQuery>();
stream_query->action_ = StreamQuery::Action::DROP_STREAM;
stream_query->stream_name_ = ctx->streamName()->symbolicName()->accept(this).as<std::string>();
return stream_query;
}
antlrcpp::Any CypherMainVisitor::visitStartStream(MemgraphCypher::StartStreamContext *ctx) {
auto *stream_query = storage_->Create<StreamQuery>();
stream_query->action_ = StreamQuery::Action::START_STREAM;
stream_query->stream_name_ = ctx->streamName()->symbolicName()->accept(this).as<std::string>();
return stream_query;
}
antlrcpp::Any CypherMainVisitor::visitStartAllStreams(MemgraphCypher::StartAllStreamsContext *ctx) {
auto *stream_query = storage_->Create<StreamQuery>();
stream_query->action_ = StreamQuery::Action::START_ALL_STREAMS;
return stream_query;
}
antlrcpp::Any CypherMainVisitor::visitStopStream(MemgraphCypher::StopStreamContext *ctx) {
auto *stream_query = storage_->Create<StreamQuery>();
stream_query->action_ = StreamQuery::Action::STOP_STREAM;
stream_query->stream_name_ = ctx->streamName()->symbolicName()->accept(this).as<std::string>();
return stream_query;
}
antlrcpp::Any CypherMainVisitor::visitStopAllStreams(MemgraphCypher::StopAllStreamsContext *ctx) {
auto *stream_query = storage_->Create<StreamQuery>();
stream_query->action_ = StreamQuery::Action::STOP_ALL_STREAMS;
return stream_query;
}
antlrcpp::Any CypherMainVisitor::visitCypherUnion(MemgraphCypher::CypherUnionContext *ctx) {
bool distinct = !ctx->ALL();
auto *cypher_union = storage_->Create<CypherUnion>(distinct);
@ -616,12 +703,7 @@ antlrcpp::Any CypherMainVisitor::visitCallProcedure(MemgraphCypher::CallProcedur
auto *call_proc = storage_->Create<CallProcedure>();
MG_ASSERT(!ctx->procedureName()->symbolicName().empty());
std::vector<std::string> procedure_subnames;
procedure_subnames.reserve(ctx->procedureName()->symbolicName().size());
for (auto *subname : ctx->procedureName()->symbolicName()) {
procedure_subnames.emplace_back(subname->accept(this).as<std::string>());
}
utils::Join(&call_proc->procedure_name_, procedure_subnames, ".");
call_proc->procedure_name_ = JoinSymbolicNames(this, ctx->procedureName()->symbolicName());
call_proc->arguments_.reserve(ctx->expression().size());
for (auto *expr : ctx->expression()) {
call_proc->arguments_.push_back(expr->accept(this));
@ -880,6 +962,7 @@ antlrcpp::Any CypherMainVisitor::visitPrivilege(MemgraphCypher::PrivilegeContext
if (ctx->TRIGGER()) return AuthQuery::Privilege::TRIGGER;
if (ctx->CONFIG()) return AuthQuery::Privilege::CONFIG;
if (ctx->DURABILITY()) return AuthQuery::Privilege::DURABILITY;
if (ctx->STREAM()) return AuthQuery::Privilege::STREAM;
LOG_FATAL("Should not get here - unknown privilege!");
}

View File

@ -248,6 +248,41 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
*/
antlrcpp::Any visitCreateSnapshotQuery(MemgraphCypher::CreateSnapshotQueryContext *ctx) override;
/**
* @return StreamQuery*
*/
antlrcpp::Any visitStreamQuery(MemgraphCypher::StreamQueryContext *ctx) override;
/**
* @return StreamQuery*
*/
antlrcpp::Any visitCreateStream(MemgraphCypher::CreateStreamContext *ctx) override;
/**
* @return StreamQuery*
*/
antlrcpp::Any visitDropStream(MemgraphCypher::DropStreamContext *ctx) override;
/**
* @return StreamQuery*
*/
antlrcpp::Any visitStartStream(MemgraphCypher::StartStreamContext *ctx) override;
/**
* @return StreamQuery*
*/
antlrcpp::Any visitStartAllStreams(MemgraphCypher::StartAllStreamsContext *ctx) override;
/**
* @return StreamQuery*
*/
antlrcpp::Any visitStopStream(MemgraphCypher::StopStreamContext *ctx) override;
/**
* @return StreamQuery*
*/
antlrcpp::Any visitStopAllStreams(MemgraphCypher::StopAllStreamsContext *ctx) override;
/**
* @return CypherUnion*
*/

View File

@ -12,12 +12,16 @@ memgraphCypherKeyword : cypherKeyword
| ASYNC
| AUTH
| BAD
| BATCHES
| BATCH_INTERVAL
| BATCH_SIZE
| BEFORE
| CLEAR
| CONFIG
| CSV
| COMMIT
| COMMITTED
| CONFIG
| CONSUMER_GROUP
| CSV
| DATA
| DELIMITER
| DATABASE
@ -54,13 +58,18 @@ memgraphCypherKeyword : cypherKeyword
| QUOTE
| SESSION
| SNAPSHOT
| START
| STATS
| STREAM
| STREAMS
| SYNC
| TRANSACTION
| TRIGGER
| TRIGGERS
| TIMEOUT
| TO
| TOPICS
| TRANSACTION
| TRANSFORM
| TRIGGER
| TRIGGERS
| UNCOMMITTED
| UNLOCK
| UPDATE
@ -87,6 +96,7 @@ query : cypherQuery
| triggerQuery
| isolationLevelQuery
| createSnapshotQuery
| streamQuery
;
authQuery : createRole
@ -131,6 +141,14 @@ clause : cypherMatch
| loadCsv
;
streamQuery : createStream
| dropStream
| startStream
| startAllStreams
| stopStream
| stopAllStreams
;
loadCsv : LOAD CSV FROM csvFile ( WITH | NO ) HEADER
( IGNORE BAD ) ?
( DELIMITER delimiter ) ?
@ -189,6 +207,7 @@ privilege : CREATE
| TRIGGER
| CONFIG
| DURABILITY
| STREAM
;
privilegeList : privilege ( ',' privilege )* ;
@ -244,3 +263,26 @@ isolationLevelScope : GLOBAL | SESSION | NEXT ;
isolationLevelQuery : SET isolationLevelScope TRANSACTION ISOLATION LEVEL isolationLevel ;
createSnapshotQuery : CREATE SNAPSHOT ;
streamName : symbolicName ;
symbolicNameWithDots : symbolicName ( DOT symbolicName )* ;
topicNames : symbolicNameWithDots ( COMMA symbolicNameWithDots )* ;
createStream : CREATE STREAM streamName
TOPICS topicNames
TRANSFORM transformationName=symbolicNameWithDots
( CONSUMER_GROUP consumerGroup=symbolicNameWithDots ) ?
( BATCH_INTERVAL batchInterval=literal ) ?
( BATCH_SIZE batchSize=literal ) ? ;
dropStream : DROP STREAM streamName ;
startStream : START STREAM streamName ;
startAllStreams : START ALL STREAMS ;
stopStream : STOP STREAM streamName ;
stopAllStreams : STOP ALL STREAMS ;
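
For orientation, a few queries that this grammar accepts (the stream, topic, and transformation names below are made up):

CREATE STREAM myStream TOPICS ns1.topic1, topic2 TRANSFORM my_module.my_transform CONSUMER_GROUP my_group BATCH_INTERVAL 100 BATCH_SIZE 1000;
START STREAM myStream;
STOP ALL STREAMS;
DROP STREAM myStream;

Note that the optional clauses of createStream must appear in grammar order: CONSUMER_GROUP, then BATCH_INTERVAL, then BATCH_SIZE.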

View File

@ -17,11 +17,15 @@ ALTER : A L T E R ;
ASYNC : A S Y N C ;
AUTH : A U T H ;
BAD : B A D ;
BATCHES : B A T C H E S ;
BATCH_INTERVAL : B A T C H UNDERSCORE I N T E R V A L ;
BATCH_SIZE : B A T C H UNDERSCORE S I Z E ;
BEFORE : B E F O R E ;
CLEAR : C L E A R ;
COMMIT : C O M M I T ;
COMMITTED : C O M M I T T E D ;
CONFIG : C O N F I G ;
CONSUMER_GROUP : C O N S U M E R UNDERSCORE G R O U P ;
CSV : C S V ;
DATA : D A T A ;
DELIMITER : D E L I M I T E R ;
@ -65,11 +69,17 @@ ROLES : R O L E S ;
QUOTE : Q U O T E ;
SESSION : S E S S I O N ;
SNAPSHOT : S N A P S H O T ;
START : S T A R T ;
STATS : S T A T S ;
STOP : S T O P ;
STREAM : S T R E A M ;
STREAMS : S T R E A M S ;
SYNC : S Y N C ;
TIMEOUT : T I M E O U T ;
TO : T O ;
TOPICS : T O P I C S;
TRANSACTION : T R A N S A C T I O N ;
TRANSFORM : T R A N S F O R M ;
TRIGGER : T R I G G E R ;
TRIGGERS : T R I G G E R S ;
UNCOMMITTED : U N C O M M I T T E D ;

View File

@ -55,6 +55,8 @@ class PrivilegeExtractor : public QueryVisitor<void>, public HierarchicalTreeVis
void Visit(TriggerQuery &trigger_query) override { AddPrivilege(AuthQuery::Privilege::TRIGGER); }
void Visit(StreamQuery &stream_query) override { AddPrivilege(AuthQuery::Privilege::STREAM); }
void Visit(ReplicationQuery &replication_query) override { AddPrivilege(AuthQuery::Privilege::REPLICATION); }
void Visit(IsolationLevelQuery &isolation_level_query) override { AddPrivilege(AuthQuery::Privilege::CONFIG); }

View File

@ -78,20 +78,60 @@ class Trie {
const int kBitsetSize = 65536;
const trie::Trie kKeywords = {
"union", "all", "optional", "match", "unwind", "as", "merge", "on",
"create", "set", "detach", "delete", "remove", "with", "distinct", "return",
"order", "by", "skip", "limit", "ascending", "asc", "descending", "desc",
"where", "or", "xor", "and", "not", "in", "starts", "ends",
"contains", "is", "null", "case", "when", "then", "else", "end",
"count", "filter", "extract", "any", "none", "single", "true", "false",
"reduce", "coalesce", "user", "password", "alter", "drop", "show", "stats",
"unique", "explain", "profile", "storage", "index", "info", "exists", "assert",
"constraint", "node", "key", "dump", "database", "call", "yield", "memory",
"mb", "kb", "unlimited", "free", "procedure", "query", "free_memory", "read_file",
"lock_path", "after", "before", "execute", "transaction", "trigger", "triggers", "update",
"comitted", "uncomitted", "global", "isolation", "level", "next", "read", "session",
"snapshot", "transaction"};
const trie::Trie kKeywords = {"union", "all",
"optional", "match",
"unwind", "as",
"merge", "on",
"create", "set",
"detach", "delete",
"remove", "with",
"distinct", "return",
"order", "by",
"skip", "limit",
"ascending", "asc",
"descending", "desc",
"where", "or",
"xor", "and",
"not", "in",
"starts", "ends",
"contains", "is",
"null", "case",
"when", "then",
"else", "end",
"count", "filter",
"extract", "any",
"none", "single",
"true", "false",
"reduce", "coalesce",
"user", "password",
"alter", "drop",
"show", "stats",
"unique", "explain",
"profile", "storage",
"index", "info",
"exists", "assert",
"constraint", "node",
"key", "dump",
"database", "call",
"yield", "memory",
"mb", "kb",
"unlimited", "free",
"procedure", "query",
"free_memory", "read_file",
"lock_path", "after",
"before", "execute",
"transaction", "trigger",
"triggers", "update",
"comitted", "uncomitted",
"global", "isolation",
"level", "next",
"read", "session",
"snapshot", "transaction",
"batches", "batch_interval",
"batch_size", "consumer_group",
"start", "stream",
"streams", "transform",
"topics"};
// Unicode codepoints that are allowed at the start of the unescaped name.
const std::bitset<kBitsetSize> kUnescapedNameAllowedStarts(

View File

@ -1,6 +1,7 @@
#include "query/interpreter.hpp"
#include <atomic>
#include <chrono>
#include <limits>
#include "glue/communication.hpp"
@ -73,6 +74,18 @@ TypedValue EvaluateOptionalExpression(Expression *expression, ExpressionEvaluato
return expression ? expression->Accept(*eval) : TypedValue();
}
template <typename TResult>
std::optional<TResult> GetOptionalValue(query::Expression *expression, ExpressionEvaluator &evaluator) {
if (expression != nullptr) {
auto int_value = expression->Accept(evaluator);
MG_ASSERT(int_value.IsNull() || int_value.IsInt());
if (int_value.IsInt()) {
return TResult{int_value.ValueInt()};
}
}
return {};
};
class ReplQueryHandler final : public query::ReplicationQueryHandler {
public:
explicit ReplQueryHandler(storage::Storage *db) : db_(db) {}
@ -447,6 +460,82 @@ Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters &
}
}
Callback HandleStreamQuery(StreamQuery *stream_query, const Parameters &parameters,
InterpreterContext *interpreter_context, DbAccessor *db_accessor) {
Frame frame(0);
SymbolTable symbol_table;
EvaluationContext evaluation_context;
// TODO: MemoryResource for EvaluationContext, it should probably be passed as
// the argument to Callback.
evaluation_context.timestamp =
std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch())
.count();
evaluation_context.parameters = parameters;
ExpressionEvaluator evaluator(&frame, symbol_table, evaluation_context, db_accessor, storage::View::OLD);
Callback callback;
switch (stream_query->action_) {
case StreamQuery::Action::CREATE_STREAM: {
constexpr std::string_view kDefaultConsumerGroup = "mg_consumer";
std::string consumer_group{stream_query->consumer_group_.empty() ? kDefaultConsumerGroup
: stream_query->consumer_group_};
callback.fn = [interpreter_context, stream_name = stream_query->stream_name_,
topic_names = stream_query->topic_names_, consumer_group = std::move(consumer_group),
batch_interval =
GetOptionalValue<std::chrono::milliseconds>(stream_query->batch_interval_, evaluator),
batch_size = GetOptionalValue<int64_t>(stream_query->batch_size_, evaluator),
transformation_name = stream_query->transform_name_]() mutable {
interpreter_context->streams.Create(stream_name, query::StreamInfo{.topics = std::move(topic_names),
.consumer_group = std::move(consumer_group),
.batch_interval = batch_interval,
.batch_size = batch_size,
.transformation_name = "transform.trans"});
return std::vector<std::vector<TypedValue>>{};
};
return callback;
}
case StreamQuery::Action::START_STREAM: {
callback.fn = [interpreter_context, stream_name = stream_query->stream_name_]() {
interpreter_context->streams.Start(stream_name);
return std::vector<std::vector<TypedValue>>{};
};
return callback;
}
case StreamQuery::Action::START_ALL_STREAMS: {
callback.fn = [interpreter_context]() {
interpreter_context->streams.StartAll();
return std::vector<std::vector<TypedValue>>{};
};
return callback;
}
case StreamQuery::Action::STOP_STREAM: {
callback.fn = [interpreter_context, stream_name = stream_query->stream_name_]() {
interpreter_context->streams.Stop(stream_name);
return std::vector<std::vector<TypedValue>>{};
};
return callback;
}
case StreamQuery::Action::STOP_ALL_STREAMS: {
callback.fn = [interpreter_context]() {
interpreter_context->streams.StopAll();
return std::vector<std::vector<TypedValue>>{};
};
return callback;
}
case StreamQuery::Action::DROP_STREAM: {
callback.fn = [interpreter_context, stream_name = stream_query->stream_name_]() {
interpreter_context->streams.Drop(stream_name);
return std::vector<std::vector<TypedValue>>{};
};
return callback;
}
case StreamQuery::Action::SHOW_STREAMS:
case StreamQuery::Action::TEST_STREAM:
throw std::logic_error("not implemented");
}
}
// Struct for lazy pulling from a vector
struct PullPlanVector {
explicit PullPlanVector(std::vector<std::vector<TypedValue>> values) : values_(std::move(values)) {}
@ -1167,6 +1256,34 @@ PreparedQuery PrepareTriggerQuery(ParsedQuery parsed_query, const bool in_explic
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
}
PreparedQuery PrepareStreamQuery(ParsedQuery parsed_query, const bool in_explicit_transaction,
InterpreterContext *interpreter_context, DbAccessor *dba,
const std::map<std::string, storage::PropertyValue> &user_parameters) {
if (in_explicit_transaction) {
throw StreamQueryInMulticommandTxException();
}
auto *stream_query = utils::Downcast<StreamQuery>(parsed_query.query);
MG_ASSERT(stream_query);
auto callback = HandleStreamQuery(stream_query, parsed_query.parameters, interpreter_context, dba);
return PreparedQuery{std::move(callback.header), std::move(parsed_query.required_privileges),
[callback_fn = std::move(callback.fn), pull_plan = std::shared_ptr<PullPlanVector>{nullptr}](
AnyStream *stream, std::optional<int> n) mutable -> std::optional<QueryHandlerResult> {
if (UNLIKELY(!pull_plan)) {
pull_plan = std::make_shared<PullPlanVector>(callback_fn());
}
if (pull_plan->Pull(stream, n)) {
return QueryHandlerResult::COMMIT;
}
return std::nullopt;
},
RWType::NONE};
// False positive report for the std::make_shared above
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
}
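
The shape of the returned callback is worth noting: the stream-query handler runs exactly once, on the first pull, and its rows are then cached and streamed out. A condensed sketch of that lazy, run-once pattern (types simplified; the real code uses PullPlanVector and TypedValue):

#include <functional>
#include <optional>
#include <vector>

using Row = std::vector<int>;  // stand-in for std::vector<TypedValue>

auto MakePullFn(std::function<std::vector<Row>()> work) {
  // Mutable lambda so the cached rows survive across pulls.
  return [work = std::move(work), rows = std::optional<std::vector<Row>>{}]() mutable {
    if (!rows) rows = work();  // executed exactly once, on the first pull
    return *rows;              // subsequent pulls reuse the cached result
  };
}
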
constexpr auto ToStorageIsolationLevel(const IsolationLevelQuery::IsolationLevel isolation_level) noexcept {
switch (isolation_level) {
case IsolationLevelQuery::IsolationLevel::SNAPSHOT_ISOLATION:
@ -1575,6 +1692,9 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
} else if (utils::Downcast<TriggerQuery>(parsed_query.query)) {
prepared_query = PrepareTriggerQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_,
&*execution_db_accessor_, params);
} else if (utils::Downcast<StreamQuery>(parsed_query.query)) {
prepared_query = PrepareStreamQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_,
&*execution_db_accessor_, params);
} else if (utils::Downcast<IsolationLevelQuery>(parsed_query.query)) {
prepared_query =
PrepareIsolationLevelQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_, this);

View File

@ -4,19 +4,19 @@
#include <string_view>
#include <utility>
#include <spdlog/spdlog.h>
#include <json/json.hpp>
#include "query/interpreter.hpp"
#include "utils/on_scope_exit.hpp"
namespace query {
namespace {
utils::SkipList<StreamData>::Iterator GetStream(utils::SkipList<StreamData>::Accessor &accessor,
const std::string &stream_name) {
auto it = accessor.find(stream_name);
if (it == accessor.end()) {
throw StreamsException("Couldn't find stream '{}'", stream_name);
auto GetStream(auto &map, const std::string &stream_name) {
if (auto it = map.find(stream_name); it != map.end()) {
return it;
}
return it;
throw StreamsException("Couldn't find stream '{}'", stream_name);
}
} // namespace
@ -77,14 +77,6 @@ void from_json(const nlohmann::json &data, StreamStatus &status) {
data.at(kIsRunningKey).get_to(status.is_running);
}
bool operator==(const StreamData &lhs, const StreamData &rhs) { return lhs.name == rhs.name; }
// NOLINTNEXTLINE(modernize-use-nullptr)
bool operator<(const StreamData &lhs, const StreamData &rhs) { return lhs.name < rhs.name; }
bool operator==(const StreamData &stream, const std::string &stream_name) { return stream.name == stream_name; }
// NOLINTNEXTLINE(modernize-use-nullptr)
bool operator<(const StreamData &stream, const std::string &stream_name) { return stream.name < stream_name; }
Streams::Streams(InterpreterContext *interpreter_context, std::string bootstrap_servers,
std::filesystem::path directory)
: interpreter_context_(interpreter_context),
@ -93,8 +85,8 @@ Streams::Streams(InterpreterContext *interpreter_context, std::string bootstrap_
void Streams::RestoreStreams() {
spdlog::info("Loading streams...");
auto accessor = streams_.access();
MG_ASSERT(accessor.size() == 0, "Cannot restore streams when some streams already exist!");
auto locked_streams_map = streams_.Lock();
MG_ASSERT(locked_streams_map->empty(), "Cannot restore streams when some streams already exist!");
for (const auto &[stream_name, stream_data] : storage_) {
const auto get_failed_message = [](const std::string_view stream_name, const std::string_view message,
@ -115,7 +107,10 @@ void Streams::RestoreStreams() {
MG_ASSERT(status.name == stream_name, "Expected stream name is '{}', but got '{}'", stream_name, status.name);
try {
CreateConsumer(accessor, stream_name, std::move(status.info), status.is_running, false);
auto it = CreateConsumer(*locked_streams_map, stream_name, std::move(status.info));
if (status.is_running) {
it->second.consumer->Lock()->Start();
}
} catch (const utils::BasicException &exception) {
spdlog::warn(get_failed_message(stream_name, "unexpected error", exception.what()));
}
@ -123,16 +118,28 @@ void Streams::RestoreStreams() {
}
void Streams::Create(const std::string &stream_name, StreamInfo info) {
auto accessor = streams_.access();
CreateConsumer(accessor, stream_name, std::move(info), false, true);
auto locked_streams = streams_.Lock();
auto it = CreateConsumer(*locked_streams, stream_name, std::move(info));
try {
Persist(CreateStatus(stream_name, it->second.transformation_name, *it->second.consumer->ReadLock()));
} catch (...) {
locked_streams->erase(it);
throw;
}
}
void Streams::Drop(const std::string &stream_name) {
auto accessor = streams_.access();
auto locked_streams = streams_.Lock();
if (!accessor.remove(stream_name)) {
throw StreamsException("Couldn't find stream '{}'", stream_name);
}
auto it = GetStream(*locked_streams, stream_name);
// streams_ is write locked, which means there is no access to it outside of this function, thus only an already
// running Test function can be using the consumer, nothing else.
// By acquiring the write lock here for the consumer, we make sure there is
// no running Test function for this consumer, therefore it can be erased.
it->second.consumer->Lock();
locked_streams->erase(it);
if (!storage_.Delete(stream_name)) {
throw StreamsException("Couldn't delete stream '{}' from persistent store!", stream_name);
@ -142,73 +149,76 @@ void Streams::Drop(const std::string &stream_name) {
}
void Streams::Start(const std::string &stream_name) {
auto accessor = streams_.access();
auto it = GetStream(accessor, stream_name);
auto locked_streams = streams_.Lock();
auto it = GetStream(*locked_streams, stream_name);
auto locked_consumer = it->consumer->Lock();
auto locked_consumer = it->second.consumer->Lock();
locked_consumer->Start();
Persist(CreateStatus(it->name, it->transformation_name, *locked_consumer));
Persist(CreateStatus(stream_name, it->second.transformation_name, *locked_consumer));
}
void Streams::Stop(const std::string &stream_name) {
auto accessor = streams_.access();
auto it = GetStream(accessor, stream_name);
auto locked_streams = streams_.Lock();
auto it = GetStream(*locked_streams, stream_name);
auto locked_consumer = it->consumer->Lock();
auto locked_consumer = it->second.consumer->Lock();
locked_consumer->Stop();
Persist(CreateStatus(it->name, it->transformation_name, *locked_consumer));
Persist(CreateStatus(stream_name, it->second.transformation_name, *locked_consumer));
}
void Streams::StartAll() {
for (auto &stream_data : streams_.access()) {
stream_data.consumer->WithLock([this, &stream_data](auto &consumer) {
if (!consumer.IsRunning()) {
consumer.Start();
Persist(CreateStatus(stream_data.name, stream_data.transformation_name, consumer));
}
});
for (auto locked_streams = streams_.Lock(); auto &[stream_name, stream_data] : *locked_streams) {
auto locked_consumer = stream_data.consumer->Lock();
if (!locked_consumer->IsRunning()) {
locked_consumer->Start();
Persist(CreateStatus(stream_name, stream_data.transformation_name, *locked_consumer));
}
}
}
void Streams::StopAll() {
for (auto &stream_data : streams_.access()) {
stream_data.consumer->WithLock([this, &stream_data](auto &consumer) {
if (consumer.IsRunning()) {
consumer.Stop();
Persist(CreateStatus(stream_data.name, stream_data.transformation_name, consumer));
}
});
for (auto locked_streams = streams_.Lock(); auto &[stream_name, stream_data] : *locked_streams) {
auto locked_consumer = stream_data.consumer->Lock();
if (locked_consumer->IsRunning()) {
locked_consumer->Stop();
Persist(CreateStatus(stream_name, stream_data.transformation_name, *locked_consumer));
}
}
}
std::vector<StreamStatus> Streams::Show() const {
std::vector<StreamStatus> result;
{
for (const auto &stream_data : streams_.access()) {
// Create string
for (auto locked_streams = streams_.ReadLock(); const auto &[stream_name, stream_data] : *locked_streams) {
result.emplace_back(
CreateStatus(stream_data.name, stream_data.transformation_name, *stream_data.consumer->ReadLock()));
CreateStatus(stream_name, stream_data.transformation_name, *stream_data.consumer->ReadLock()));
}
}
return result;
}
TransformationResult Streams::Test(const std::string &stream_name, std::optional<int64_t> batch_limit) {
auto accessor = streams_.access();
auto it = GetStream(accessor, stream_name);
TransformationResult Streams::Test(const std::string &stream_name, std::optional<int64_t> batch_limit) const {
TransformationResult result;
auto consumer_function = [&result](const std::vector<Message> &messages) {
for (const auto &message : messages) {
// TODO(antaljanosbenjamin) Update the logic with using the transform from modules
const auto payload = message.Payload();
const std::string_view payload_as_string_view{payload.data(), payload.size()};
spdlog::info("CREATE (n:MESSAGE {{payload: '{}'}})", payload_as_string_view);
result[fmt::format("CREATE (n:MESSAGE {{payload: '{}'}})", payload_as_string_view)] = "replace with params";
}
};
it->consumer->Lock()->Test(batch_limit, consumer_function);
// This depends on the fact that Drop will first acquire a write lock on the consumer, and erases it only after that
auto locked_consumer = [this, &stream_name] {
auto locked_streams = streams_.ReadLock();
auto it = GetStream(*locked_streams, stream_name);
return it->second.consumer->ReadLock();
}();
locked_consumer->Test(batch_limit, consumer_function);
return result;
}
@ -227,9 +237,9 @@ StreamStatus Streams::CreateStatus(const std::string &name, const std::string &t
consumer.IsRunning()};
}
void Streams::CreateConsumer(utils::SkipList<StreamData>::Accessor &accessor, const std::string &stream_name,
StreamInfo info, const bool start_consumer, const bool persist_consumer) {
if (accessor.contains(stream_name)) {
Streams::StreamsMap::iterator Streams::CreateConsumer(StreamsMap &map, const std::string &stream_name,
StreamInfo stream_info) {
if (map.contains(stream_name)) {
throw StreamsException{"Stream already exists with name '{}'", stream_name};
}
@ -255,26 +265,18 @@ void Streams::CreateConsumer(utils::SkipList<StreamData>::Accessor &accessor, co
ConsumerInfo consumer_info{
.consumer_name = stream_name,
.topics = std::move(info.topics),
.consumer_group = std::move(info.consumer_group),
.batch_interval = info.batch_interval,
.batch_size = info.batch_size,
.topics = std::move(stream_info.topics),
.consumer_group = std::move(stream_info.consumer_group),
.batch_interval = stream_info.batch_interval,
.batch_size = stream_info.batch_size,
};
auto consumer = std::make_unique<SynchronizedConsumer>(bootstrap_servers_, std::move(consumer_info),
std::move(consumer_function));
auto locked_consumer = consumer->Lock();
if (start_consumer) {
locked_consumer->Start();
}
if (persist_consumer) {
Persist(CreateStatus(stream_name, info.transformation_name, *locked_consumer));
}
auto insert_result =
accessor.insert(StreamData{stream_name, std::move(info.transformation_name), std::move(consumer)});
auto insert_result = map.insert_or_assign(
stream_name, StreamData{std::move(stream_info.transformation_name),
std::make_unique<SynchronizedConsumer>(bootstrap_servers_, std::move(consumer_info),
std::move(consumer_function))});
MG_ASSERT(insert_result.second, "Unexpected error during storing consumer '{}'", stream_name);
return insert_result.first;
}
void Streams::Persist(StreamStatus &&status) {
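
The Drop/Test comments above describe a lock-ordering protocol: the streams map lock is always taken first, and Drop additionally write-locks the consumer before erasing it, so an in-flight Test (which read-locks the consumer and then releases the map lock) is always drained before its consumer disappears. A condensed sketch, with std::shared_mutex standing in for utils::Synchronized and the RW locks (an illustration of the ordering, not the actual types):

#include <map>
#include <memory>
#include <mutex>
#include <shared_mutex>
#include <string>

struct Consumer { /* ... */ };
struct StreamData {
  std::unique_ptr<std::shared_mutex> mtx = std::make_unique<std::shared_mutex>();
  Consumer consumer;
};

std::shared_mutex map_mtx;
std::map<std::string, StreamData> streams;

void Drop(const std::string &name) {
  std::unique_lock map_lock(map_mtx);  // exclusive: no new Test can look up a consumer
  auto it = streams.find(name);
  { std::unique_lock drain(*it->second.mtx); }  // wait out any Test still using it
  streams.erase(it);  // safe: nobody can reach this consumer anymore
}

void Test(const std::string &name) {
  std::shared_lock consumer_lock = [&] {
    std::shared_lock map_lock(map_mtx);  // shared: Tests don't block each other
    return std::shared_lock(*streams.find(name)->second.mtx);
  }();  // the map lock is released here, before the long-running dry run
  // ... long-running dry run against the consumer ...
}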

View File

@ -10,7 +10,6 @@
#include "kvstore/kvstore.hpp"
#include "utils/exceptions.hpp"
#include "utils/rw_lock.hpp"
#include "utils/skip_list.hpp"
#include "utils/synchronized.hpp"
namespace query {
@ -42,7 +41,6 @@ struct StreamStatus {
using SynchronizedConsumer = utils::Synchronized<integrations::kafka::Consumer, utils::WritePrioritizedRWLock>;
struct StreamData {
std::string name;
// TODO(antaljanosbenjamin) How to reference the transformation in a better way?
std::string transformation_name;
// TODO(antaljanosbenjamin) consider propagate_const
@ -128,14 +126,16 @@ class Streams final {
/// @throws StreamsException if the stream doesn't exist
/// @throws ConsumerRunningException if the consumer is alredy running
/// @throws ConsumerTestFailedException if the transformation function throws any std::exception during processing
TransformationResult Test(const std::string &stream_name, std::optional<int64_t> batch_limit = std::nullopt);
TransformationResult Test(const std::string &stream_name, std::optional<int64_t> batch_limit = std::nullopt) const;
private:
using StreamsMap = std::unordered_map<std::string, StreamData>;
using SynchronizedStreamsMap = utils::Synchronized<StreamsMap, utils::WritePrioritizedRWLock>;
static StreamStatus CreateStatus(const std::string &name, const std::string &transformation_name,
const integrations::kafka::Consumer &consumer);
void CreateConsumer(utils::SkipList<StreamData>::Accessor &accessor, const std::string &stream_name,
StreamInfo stream_info, const bool start_consumer, const bool persist_consumer);
StreamsMap::iterator CreateConsumer(StreamsMap &map, const std::string &stream_name, StreamInfo stream_info);
void Persist(StreamStatus &&status);
@ -143,7 +143,7 @@ class Streams final {
std::string bootstrap_servers_;
kvstore::KVStore storage_;
utils::SkipList<StreamData> streams_;
SynchronizedStreamsMap streams_;
};
} // namespace query

View File

@ -2067,6 +2067,8 @@ TEST_P(CypherMainVisitorTest, GrantPrivilege) {
{AuthQuery::Privilege::TRIGGER});
check_auth_query(&ast_generator, "GRANT CONFIG TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {},
{AuthQuery::Privilege::CONFIG});
check_auth_query(&ast_generator, "GRANT STREAM TO user", AuthQuery::Action::GRANT_PRIVILEGE, "", "", "user", {},
{AuthQuery::Privilege::STREAM});
}
TEST_P(CypherMainVisitorTest, DenyPrivilege) {
@ -3062,8 +3064,9 @@ TEST_P(CypherMainVisitorTest, MemoryLimit) {
}
namespace {
template <typename TException = SyntaxException>
void TestInvalidQuery(const auto &query, Base &ast_generator) {
ASSERT_THROW(ast_generator.ParseQuery(query), SyntaxException);
EXPECT_THROW(ast_generator.ParseQuery(query), TException) << query;
}
} // namespace
@ -3203,4 +3206,168 @@ TEST_P(CypherMainVisitorTest, CreateSnapshotQuery) {
auto &ast_generator = *GetParam();
ASSERT_TRUE(dynamic_cast<CreateSnapshotQuery *>(ast_generator.ParseQuery("CREATE SNAPSHOT")));
}
void ValidateMostlyEmptyStreamQuery(Base &ast_generator, const std::string &query_string,
const StreamQuery::Action action, const std::string_view stream_name) {
auto *parsed_query = dynamic_cast<StreamQuery *>(ast_generator.ParseQuery(query_string));
ASSERT_NE(parsed_query, nullptr);
EXPECT_EQ(parsed_query->action_, action);
EXPECT_EQ(parsed_query->stream_name_, stream_name);
EXPECT_TRUE(parsed_query->topic_names_.empty());
EXPECT_TRUE(parsed_query->transform_name_.empty());
EXPECT_TRUE(parsed_query->consumer_group_.empty());
EXPECT_EQ(parsed_query->batch_interval_, nullptr);
EXPECT_EQ(parsed_query->batch_size_, nullptr);
}
TEST_P(CypherMainVisitorTest, DropStream) {
auto &ast_generator = *GetParam();
TestInvalidQuery("DROP ST", ast_generator);
TestInvalidQuery("DROP STREAM", ast_generator);
TestInvalidQuery("DROP STREAMS", ast_generator);
ValidateMostlyEmptyStreamQuery(ast_generator, "DrOP STREAm droppedStream", StreamQuery::Action::DROP_STREAM,
"droppedStream");
}
TEST_P(CypherMainVisitorTest, StartStream) {
auto &ast_generator = *GetParam();
TestInvalidQuery("START ST", ast_generator);
TestInvalidQuery("START STREAM", ast_generator);
TestInvalidQuery("START STREAMS", ast_generator);
ValidateMostlyEmptyStreamQuery(ast_generator, "START STREAM startedStream", StreamQuery::Action::START_STREAM,
"startedStream");
}
TEST_P(CypherMainVisitorTest, StartAllStreams) {
auto &ast_generator = *GetParam();
TestInvalidQuery("START ALL", ast_generator);
TestInvalidQuery("START ALL STREAM", ast_generator);
TestInvalidQuery("START STREAMS ALL", ast_generator);
ValidateMostlyEmptyStreamQuery(ast_generator, "StARt AlL StrEAMS", StreamQuery::Action::START_ALL_STREAMS, "");
}
TEST_P(CypherMainVisitorTest, StopStream) {
auto &ast_generator = *GetParam();
TestInvalidQuery("STOP ST", ast_generator);
TestInvalidQuery("STOP STREAM", ast_generator);
TestInvalidQuery("STOP STREAMS", ast_generator);
TestInvalidQuery("STOP STREAM invalid stream name", ast_generator);
ValidateMostlyEmptyStreamQuery(ast_generator, "STOP stREAM stoppedStream", StreamQuery::Action::STOP_STREAM,
"stoppedStream");
}
TEST_P(CypherMainVisitorTest, StopAllStreams) {
auto &ast_generator = *GetParam();
TestInvalidQuery("STOP ALL", ast_generator);
TestInvalidQuery("STOP ALL STREAM", ast_generator);
TestInvalidQuery("STOP STREAMS ALL", ast_generator);
ValidateMostlyEmptyStreamQuery(ast_generator, "SToP ALL STReaMS", StreamQuery::Action::STOP_ALL_STREAMS, "");
}
void ValidateCreateStreamQuery(Base &ast_generator, const std::string &query_string, const std::string_view stream_name,
const std::vector<std::string> &topic_names, const std::string_view transform_name,
const std::string_view consumer_group, const std::optional<TypedValue> &batch_interval,
const std::optional<TypedValue> &batch_size) {
StreamQuery *parsed_query{nullptr};
ASSERT_NO_THROW(parsed_query = dynamic_cast<StreamQuery *>(ast_generator.ParseQuery(query_string))) << query_string;
ASSERT_NE(parsed_query, nullptr);
EXPECT_EQ(parsed_query->stream_name_, stream_name);
auto check_expression = [&](Expression *expression, const std::optional<TypedValue> &expected) {
EXPECT_EQ(expression != nullptr, expected.has_value());
if (expected.has_value()) {
EXPECT_NO_FATAL_FAILURE(ast_generator.CheckLiteral(expression, *expected));
}
};
EXPECT_EQ(parsed_query->topic_names_, topic_names);
EXPECT_EQ(parsed_query->transform_name_, transform_name);
EXPECT_EQ(parsed_query->consumer_group_, consumer_group);
EXPECT_NO_FATAL_FAILURE(check_expression(parsed_query->batch_interval_, batch_interval));
EXPECT_NO_FATAL_FAILURE(check_expression(parsed_query->batch_size_, batch_size));
}
TEST_P(CypherMainVisitorTest, CreateStream) {
auto &ast_generator = *GetParam();
TestInvalidQuery("CREATE STREAM", ast_generator);
TestInvalidQuery("CREATE STREAM invalid stream name TOPICS topic1 TRANSFORM transform", ast_generator);
TestInvalidQuery("CREATE STREAM stream TOPICS invalid topic name TRANSFORM transform", ast_generator);
TestInvalidQuery("CREATE STREAM stream TOPICS topic1 TRANSFORM invalid transform name", ast_generator);
TestInvalidQuery("CREATE STREAM stream TRANSFORM transform", ast_generator);
TestInvalidQuery("CREATE STREAM stream TOPICS TRANSFORM transform", ast_generator);
TestInvalidQuery("CREATE STREAM stream TOPICS topic1", ast_generator);
TestInvalidQuery("CREATE STREAM stream TOPICS topic1 TRANSFORM", ast_generator);
TestInvalidQuery("CREATE STREAM stream TOPICS topic1 TRANSFORM transform CONSUMER_GROUP", ast_generator);
TestInvalidQuery("CREATE STREAM stream TOPICS topic1 TRANSFORM transform CONSUMER_GROUP invalid consumer group",
ast_generator);
TestInvalidQuery("CREATE STREAM stream TOPICS topic1 TRANSFORM transform BATCH_INTERVAL", ast_generator);
TestInvalidQuery<SemanticException>(
"CREATE STREAM stream TOPICS topic1 TRANSFORM transform BATCH_INTERVAL 'invalid interval'", ast_generator);
TestInvalidQuery("CREATE STREAM stream TOPICS topic1 TRANSFORM transform BATCH_SIZE", ast_generator);
TestInvalidQuery<SemanticException>(
"CREATE STREAM stream TOPICS topic1 TRANSFORM transform BATCH_SIZE 'invalid size'", ast_generator);
TestInvalidQuery("CREATE STREAM stream TOPICS topic1 TRANSFORM transform BATCH_SIZE 2 BATCH_INTERVAL 3",
ast_generator);
TestInvalidQuery("CREATE STREAM stream TOPICS topic1 TRANSFORM transform BATCH_INVERVAL 2 CONSUMER_GROUP Gru",
ast_generator);
TestInvalidQuery("CREATE STREAM stream TOPICS topic1 TRANSFORM transform BATCH_SIZE 2 CONSUMER_GROUP Gru",
ast_generator);
TestInvalidQuery("CREATE STREAM stream TOPICS topic1, TRANSFORM transform BATCH_SIZE 2 CONSUMER_GROUP Gru",
ast_generator);
const std::string topic_name1{"topic1_name.with_dot"};
const std::string topic_name2{"topic1_name.with_multiple.dots"};
auto check_topic_names = [&ast_generator](const std::vector<std::string> &topic_names) {
constexpr std::string_view kStreamName{"SomeSuperStream"};
constexpr std::string_view kTransformName{"moreAwesomeTransform"};
constexpr std::string_view kConsumerGroup{"ConsumerGru"};
constexpr int kBatchInterval = 324;
const TypedValue batch_interval_value{kBatchInterval};
constexpr int kBatchSize = 1;
const TypedValue batch_size_value{kBatchSize};
const auto topic_names_as_str = utils::Join(topic_names, ",");
ValidateCreateStreamQuery(
ast_generator,
fmt::format("CREATE STREAM {} TOPICS {} TRANSFORM {}", kStreamName, topic_names_as_str, kTransformName),
kStreamName, topic_names, kTransformName, "", std::nullopt, std::nullopt);
ValidateCreateStreamQuery(ast_generator,
fmt::format("CREATE STREAM {} TOPICS {} TRANSFORM {} CONSUMER_GROUP {} ", kStreamName,
topic_names_as_str, kTransformName, kConsumerGroup),
kStreamName, topic_names, kTransformName, kConsumerGroup, std::nullopt, std::nullopt);
ValidateCreateStreamQuery(ast_generator,
fmt::format("CREATE STREAM {} TOPICS {} TRANSFORM {} BATCH_INTERVAL {}", kStreamName,
topic_names_as_str, kTransformName, kBatchInterval),
kStreamName, topic_names, kTransformName, "", batch_interval_value, std::nullopt);
ValidateCreateStreamQuery(ast_generator,
fmt::format("CREATE STREAM {} TOPICS {} TRANSFORM {} BATCH_SIZE {}", kStreamName,
topic_names_as_str, kTransformName, kBatchSize),
kStreamName, topic_names, kTransformName, "", std::nullopt, batch_size_value);
ValidateCreateStreamQuery(
ast_generator,
fmt::format("CREATE STREAM {} TOPICS {} TRANSFORM {} CONSUMER_GROUP {} BATCH_INTERVAL {} BATCH_SIZE {}",
kStreamName, topic_names_as_str, kTransformName, kConsumerGroup, kBatchInterval, kBatchSize),
kStreamName, topic_names, kTransformName, kConsumerGroup, batch_interval_value, batch_size_value);
};
EXPECT_NO_FATAL_FAILURE(check_topic_names({topic_name1}));
EXPECT_NO_FATAL_FAILURE(check_topic_names({topic_name2}));
EXPECT_NO_FATAL_FAILURE(check_topic_names({topic_name1, topic_name2}));
}
} // namespace

View File

@ -277,7 +277,7 @@ TEST_F(ConsumerTest, InvalidTopic) {
EXPECT_THROW(Consumer(cluster.Bootstraps(), std::move(info), kDummyConsumerFunction), TopicNotFoundException);
}
TEST_F(ConsumerTest, StartsFromPreviousOffset) {
TEST_F(ConsumerTest, DISABLED_StartsFromPreviousOffset) {
constexpr auto kBatchSize = 1;
auto info = CreateDefaultConsumerInfo();
info.batch_size = kBatchSize;
@ -295,20 +295,26 @@ TEST_F(ConsumerTest, StartsFromPreviousOffset) {
received_message_count = message_count;
};
// This test depends on the fact that CreateConsumer starts and stops the consumer, so the offset is stored
auto consumer = CreateConsumer(std::move(info), std::move(consumer_function));
ASSERT_FALSE(consumer->IsRunning());
{
// This test depends on the fact that CreateConsumer starts and stops the consumer, so the offset is stored
auto consumer = CreateConsumer(ConsumerInfo{info}, consumer_function);
ASSERT_FALSE(consumer->IsRunning());
}
auto send_and_consume_messages = [&](int batch_count) {
SCOPED_TRACE(fmt::format("Already received messages: {}", received_message_count.load()));
auto expected_total_messages = received_message_count + batch_count;
for (auto sent_messages = 0; sent_messages < batch_count; ++sent_messages) {
cluster.SeedTopic(kTopicName,
std::string_view{kMessagePrefix + std::to_string(received_message_count + sent_messages)});
}
auto expected_total_messages = received_message_count + batch_count;
auto consumer = std::make_unique<Consumer>(cluster.Bootstraps(), ConsumerInfo{info}, consumer_function);
ASSERT_FALSE(consumer->IsRunning());
consumer->Start();
const auto start = std::chrono::steady_clock::now();
ASSERT_TRUE(consumer->IsRunning());
constexpr auto kMaxWaitTime = std::chrono::seconds(5);
while (expected_total_messages != received_message_count.load() &&
@ -317,7 +323,7 @@ TEST_F(ConsumerTest, StartsFromPreviousOffset) {
}
// it is stopped because of limited batches
EXPECT_EQ(expected_total_messages, received_message_count);
consumer->Stop();
EXPECT_NO_THROW(consumer->Stop());
ASSERT_FALSE(consumer->IsRunning());
EXPECT_TRUE(expected_messages_received) << "Some unexpected messages have been received";
};

View File

@ -165,3 +165,8 @@ TEST_F(TestPrivilegeExtractor, CreateSnapshotQuery) {
auto *query = storage.Create<CreateSnapshotQuery>();
EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::DURABILITY));
}
TEST_F(TestPrivilegeExtractor, StreamQuery) {
auto *query = storage.Create<StreamQuery>();
EXPECT_THAT(GetRequiredPrivileges(query), UnorderedElementsAre(AuthQuery::Privilege::STREAM));
}

View File

@ -105,23 +105,23 @@ TEST_F(StreamsTest, SimpleStreamManagement) {
streams_->Create(check_data.name, check_data.info);
EXPECT_NO_FATAL_FAILURE(CheckStreamStatus(check_data));
streams_->Start(check_data.name);
EXPECT_NO_THROW(streams_->Start(check_data.name));
check_data.is_running = true;
EXPECT_NO_FATAL_FAILURE(CheckStreamStatus(check_data));
streams_->StopAll();
EXPECT_NO_THROW(streams_->StopAll());
check_data.is_running = false;
EXPECT_NO_FATAL_FAILURE(CheckStreamStatus(check_data));
streams_->StartAll();
EXPECT_NO_THROW(streams_->StartAll());
check_data.is_running = true;
EXPECT_NO_FATAL_FAILURE(CheckStreamStatus(check_data));
streams_->Stop(check_data.name);
EXPECT_NO_THROW(streams_->Stop(check_data.name));
check_data.is_running = false;
EXPECT_NO_FATAL_FAILURE(CheckStreamStatus(check_data));
streams_->Drop(check_data.name);
EXPECT_NO_THROW(streams_->Drop(check_data.name));
EXPECT_TRUE(streams_->Show().empty());
}