diff --git a/src/auth/models.cpp b/src/auth/models.cpp index a15fb703f..42b94c44d 100644 --- a/src/auth/models.cpp +++ b/src/auth/models.cpp @@ -40,6 +40,8 @@ std::string PermissionToString(Permission permission) { return "CONSTRAINT"; case Permission::DUMP: return "DUMP"; + case Permission::REPLICATION: + return "REPLICATION"; case Permission::AUTH: return "AUTH"; } diff --git a/src/auth/models.hpp b/src/auth/models.hpp index 6fd880ddf..c03492a5f 100644 --- a/src/auth/models.hpp +++ b/src/auth/models.hpp @@ -20,6 +20,7 @@ enum class Permission : uint64_t { STATS = 0x00000080, CONSTRAINT = 0x00000100, DUMP = 0x00000200, + REPLICATION = 0x00000400, AUTH = 0x00010000, }; @@ -28,7 +29,7 @@ const std::vector<Permission> kPermissionsAll = { Permission::MATCH, Permission::CREATE, Permission::MERGE, Permission::DELETE, Permission::SET, Permission::REMOVE, Permission::INDEX, Permission::STATS, Permission::CONSTRAINT, - Permission::DUMP, Permission::AUTH}; + Permission::DUMP, Permission::AUTH, Permission::REPLICATION}; // Function that converts a permission to its string representation. std::string PermissionToString(Permission permission); diff --git a/src/glue/auth.cpp b/src/glue/auth.cpp index 52e7eed13..bb6a4f1ab 100644 --- a/src/glue/auth.cpp +++ b/src/glue/auth.cpp @@ -24,6 +24,8 @@ auth::Permission PrivilegeToPermission(query::AuthQuery::Privilege privilege) { return auth::Permission::CONSTRAINT; case query::AuthQuery::Privilege::DUMP: return auth::Permission::DUMP; + case query::AuthQuery::Privilege::REPLICATION: + return auth::Permission::REPLICATION; case query::AuthQuery::Privilege::AUTH: return auth::Permission::AUTH; } diff --git a/src/query/exceptions.hpp b/src/query/exceptions.hpp index 9dea0626c..db61eae88 100644 --- a/src/query/exceptions.hpp +++ b/src/query/exceptions.hpp @@ -167,4 +167,11 @@ class InvalidArgumentsException : public QueryException { argument_name, message)) {} }; +class ReplicationModificationInMulticommandTxException : public QueryException { + public: + ReplicationModificationInMulticommandTxException() + : QueryException( + "Replication clause not allowed in multicommand transactions.") {} +}; + } // namespace query diff --git a/src/query/frontend/ast/ast.lcp b/src/query/frontend/ast/ast.lcp index 5114648a1..98e8abb22 100644 --- a/src/query/frontend/ast/ast.lcp +++ b/src/query/frontend/ast/ast.lcp @@ -2191,7 +2191,7 @@ cpp<# (:serialize)) (lcp:define-enum privilege (create delete match merge set remove index stats auth constraint - dump) + dump replication) (:serialize)) #>cpp AuthQuery() = default; @@ -2226,7 +2226,8 @@ const std::vector<AuthQuery::Privilege> kPrivilegesAll = { AuthQuery::Privilege::SET, AuthQuery::Privilege::REMOVE, AuthQuery::Privilege::INDEX, AuthQuery::Privilege::STATS, AuthQuery::Privilege::AUTH, - AuthQuery::Privilege::CONSTRAINT, AuthQuery::Privilege::DUMP}; + AuthQuery::Privilege::CONSTRAINT, AuthQuery::Privilege::DUMP, + AuthQuery::Privilege::REPLICATION}; cpp<# (lcp:define-class info-query (query) @@ -2296,4 +2297,39 @@ cpp<# (:serialize (:slk)) (:clone)) +(lcp:define-class replication-query (query) + ((action "Action" :scope :public) + (mode "ReplicationMode" :scope :public) + (replica_name "std::string" :scope :public) + (hostname "Expression *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression")) + (sync_mode "SyncMode" :scope :public) + (timeout "Expression *" :initval "nullptr" :scope :public + :slk-save #'slk-save-ast-pointer + :slk-load (slk-load-ast-pointer "Expression"))) + + (:public + (lcp:define-enum action + (set-replication-mode show-replication-mode create-replica + drop-replica show-replicas) + (:serialize)) + (lcp:define-enum replication-mode + (main replica) + (:serialize)) + (lcp:define-enum sync-mode + (sync async) + (:serialize)) + #>cpp + ReplicationQuery() = default; + + DEFVISITABLE(QueryVisitor<void>); + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + (lcp:pop-namespace) ;; namespace query diff --git a/src/query/frontend/ast/ast_visitor.hpp b/src/query/frontend/ast/ast_visitor.hpp index 3eb6632db..3e5fa61d8 100644 --- a/src/query/frontend/ast/ast_visitor.hpp +++ b/src/query/frontend/ast/ast_visitor.hpp @@ -72,6 +72,7 @@ class InfoQuery; class ConstraintQuery; class RegexMatch; class DumpQuery; +class ReplicationQuery; using TreeCompositeVisitor = ::utils::CompositeVisitor< SingleQuery, CypherUnion, NamedExpression, OrOperator, XorOperator, @@ -115,7 +116,7 @@ class ExpressionVisitor template <class TResult> class QueryVisitor : public ::utils::Visitor<TResult, CypherQuery, ExplainQuery, ProfileQuery, - IndexQuery, AuthQuery, InfoQuery, - ConstraintQuery, DumpQuery> {}; + IndexQuery, AuthQuery, InfoQuery, ConstraintQuery, + DumpQuery, ReplicationQuery> {}; } // namespace query diff --git a/src/query/frontend/ast/cypher_main_visitor.cpp b/src/query/frontend/ast/cypher_main_visitor.cpp index 499b3a7ac..ac2657c4b 100644 --- a/src/query/frontend/ast/cypher_main_visitor.cpp +++ b/src/query/frontend/ast/cypher_main_visitor.cpp @@ -196,6 +196,78 @@ antlrcpp::Any CypherMainVisitor::visitDumpQuery( return dump_query; } +antlrcpp::Any CypherMainVisitor::visitReplicationQuery( + MemgraphCypher::ReplicationQueryContext *ctx) { + CHECK(ctx->children.size() == 1) + << "ReplicationQuery should have exactly one child!"; + auto *replication_query = + ctx->children[0]->accept(this).as<ReplicationQuery *>(); + query_ = replication_query; + return replication_query; +} + +antlrcpp::Any CypherMainVisitor::visitSetReplicationMode( + MemgraphCypher::SetReplicationModeContext *ctx) { + auto *replication_query = storage_->Create<ReplicationQuery>(); + replication_query->action_ = ReplicationQuery::Action::SET_REPLICATION_MODE; + if (ctx->MAIN()) { + replication_query->mode_ = ReplicationQuery::ReplicationMode::MAIN; + } else if (ctx->REPLICA()) { + replication_query->mode_ = ReplicationQuery::ReplicationMode::REPLICA; + } + return replication_query; +} + +antlrcpp::Any CypherMainVisitor::visitShowReplicationMode( + MemgraphCypher::ShowReplicationModeContext *ctx) { + auto *replication_query = storage_->Create<ReplicationQuery>(); + replication_query->action_ = ReplicationQuery::Action::SHOW_REPLICATION_MODE; + return replication_query; +} + +antlrcpp::Any CypherMainVisitor::visitCreateReplica( + MemgraphCypher::CreateReplicaContext *ctx) { + auto *replication_query = storage_->Create<ReplicationQuery>(); + replication_query->action_ = ReplicationQuery::Action::CREATE_REPLICA; + replication_query->replica_name_ = + ctx->replicaName()->symbolicName()->accept(this).as<std::string>(); + if (ctx->SYNC()) { + replication_query->sync_mode_ = query::ReplicationQuery::SyncMode::SYNC; + } else if (ctx->ASYNC()) { + replication_query->sync_mode_ = query::ReplicationQuery::SyncMode::ASYNC; + } + if (!ctx->hostName()->literal()->StringLiteral()) { + throw SyntaxException("Hostname should be a string literal!"); + } else { + replication_query->hostname_ = ctx->hostName()->accept(this); + } + if (ctx->timeout) { + if (!ctx->timeout->numberLiteral()->doubleLiteral() && + !ctx->timeout->numberLiteral()->integerLiteral()) { + throw SyntaxException("Timeout should be a double literal!"); + } else { + replication_query->timeout_ = ctx->timeout->accept(this); + } + } + return replication_query; +} + +antlrcpp::Any CypherMainVisitor::visitDropReplica( + MemgraphCypher::DropReplicaContext *ctx) { + auto *replication_query = storage_->Create<ReplicationQuery>(); + replication_query->action_ = ReplicationQuery::Action::DROP_REPLICA; + replication_query->replica_name_ = + ctx->replicaName()->symbolicName()->accept(this).as<std::string>(); + return replication_query; +} + +antlrcpp::Any CypherMainVisitor::visitShowReplicas( + MemgraphCypher::ShowReplicasContext *ctx) { + auto *replication_query = storage_->Create<ReplicationQuery>(); + replication_query->action_ = ReplicationQuery::Action::SHOW_REPLICAS; + return replication_query; +} + antlrcpp::Any CypherMainVisitor::visitCypherUnion( MemgraphCypher::CypherUnionContext *ctx) { bool distinct = !ctx->ALL(); diff --git a/src/query/frontend/ast/cypher_main_visitor.hpp b/src/query/frontend/ast/cypher_main_visitor.hpp index dc12edeaf..9e1d165ed 100644 --- a/src/query/frontend/ast/cypher_main_visitor.hpp +++ b/src/query/frontend/ast/cypher_main_visitor.hpp @@ -186,6 +186,42 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor { */ antlrcpp::Any visitDumpQuery(MemgraphCypher::DumpQueryContext *ctx) override; + /** + * @return ReplicationQuery* + */ + antlrcpp::Any visitReplicationQuery( + MemgraphCypher::ReplicationQueryContext *ctx) override; + + /** + * @return ReplicationQuery* + */ + antlrcpp::Any visitSetReplicationMode( + MemgraphCypher::SetReplicationModeContext *ctx) override; + + /** + * @return ReplicationQuery* + */ + antlrcpp::Any visitShowReplicationMode( + MemgraphCypher::ShowReplicationModeContext *ctx) override; + + /** + * @return ReplicationQuery* + */ + antlrcpp::Any visitCreateReplica( + MemgraphCypher::CreateReplicaContext *ctx) override; + + /** + * @return ReplicationQuery* + */ + antlrcpp::Any visitDropReplica( + MemgraphCypher::DropReplicaContext *ctx) override; + + /** + * @return ReplicationQuery* + */ + antlrcpp::Any visitShowReplicas( + MemgraphCypher::ShowReplicasContext *ctx) override; + /** * @return CypherUnion* */ diff --git a/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 b/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 index 3d484b055..f35e0cc8d 100644 --- a/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 +++ b/src/query/frontend/opencypher/grammar/MemgraphCypher.g4 @@ -8,6 +8,7 @@ import Cypher ; memgraphCypherKeyword : cypherKeyword | ALTER + | ASYNC | AUTH | CLEAR | DATABASE @@ -18,12 +19,19 @@ memgraphCypherKeyword : cypherKeyword | FROM | GRANT | IDENTIFIED + | MAIN + | MODE | PASSWORD | PRIVILEGES + | REPLICA + | REPLICAS + | REPLICATION | REVOKE | ROLE | ROLES | STATS + | SYNC + | TIMEOUT | TO | USER | USERS @@ -42,6 +50,7 @@ query : cypherQuery | constraintQuery | authQuery | dumpQuery + | replicationQuery ; authQuery : createRole @@ -61,6 +70,13 @@ authQuery : createRole | showUsersForRole ; +replicationQuery : setReplicationMode + | showReplicationMode + | createReplica + | dropReplica + | showReplicas + ; + userOrRoleName : symbolicName ; createRole : CREATE ROLE role=userOrRoleName ; @@ -100,3 +116,19 @@ showRoleForUser : SHOW ROLE FOR user=userOrRoleName ; showUsersForRole : SHOW USERS FOR role=userOrRoleName ; dumpQuery: DUMP DATABASE ; + +setReplicationMode : SET REPLICATION MODE TO ( MAIN | REPLICA ) ; + +showReplicationMode : SHOW REPLICATION MODE ; + +replicaName : symbolicName ; + +hostName : literal ; + +createReplica : CREATE REPLICA replicaName ( SYNC | ASYNC ) + ( WITH TIMEOUT timeout=literal ) ? + TO hostName ; + +dropReplica : DROP REPLICA replicaName ; + +showReplicas : SHOW REPLICAS ; diff --git a/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 b/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 index e191f148c..e7eb6766f 100644 --- a/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 +++ b/src/query/frontend/opencypher/grammar/MemgraphCypherLexer.g4 @@ -11,6 +11,7 @@ lexer grammar MemgraphCypherLexer ; import CypherLexer ; ALTER : A L T E R ; +ASYNC : A S Y N C ; AUTH : A U T H ; CLEAR : C L E A R ; DATABASE : D A T A B A S E ; @@ -22,12 +23,19 @@ FROM : F R O M ; GRANT : G R A N T ; GRANTS : G R A N T S ; IDENTIFIED : I D E N T I F I E D ; +MAIN : M A I N ; +MODE : M O D E ; PASSWORD : P A S S W O R D ; PRIVILEGES : P R I V I L E G E S ; +REPLICA : R E P L I C A ; +REPLICAS : R E P L I C A S ; +REPLICATION : R E P L I C A T I O N ; REVOKE : R E V O K E ; ROLE : R O L E ; ROLES : R O L E S ; STATS : S T A T S ; +SYNC : S Y N C ; +TIMEOUT : T I M E O U T ; TO : T O ; USER : U S E R ; USERS : U S E R S ; diff --git a/src/query/frontend/semantic/required_privileges.cpp b/src/query/frontend/semantic/required_privileges.cpp index d2d13ad81..9e1564329 100644 --- a/src/query/frontend/semantic/required_privileges.cpp +++ b/src/query/frontend/semantic/required_privileges.cpp @@ -59,6 +59,26 @@ class PrivilegeExtractor : public QueryVisitor<void>, AddPrivilege(AuthQuery::Privilege::DUMP); } + void Visit(ReplicationQuery &replication_query) override { + switch (replication_query.action_) { + case ReplicationQuery::Action::SET_REPLICATION_MODE: + AddPrivilege(AuthQuery::Privilege::REPLICATION); + break; + case ReplicationQuery::Action::SHOW_REPLICATION_MODE: + AddPrivilege(AuthQuery::Privilege::REPLICATION); + break; + case ReplicationQuery::Action::CREATE_REPLICA: + AddPrivilege(AuthQuery::Privilege::REPLICATION); + break; + case ReplicationQuery::Action::DROP_REPLICA: + AddPrivilege(AuthQuery::Privilege::REPLICATION); + break; + case ReplicationQuery::Action::SHOW_REPLICAS: + AddPrivilege(AuthQuery::Privilege::REPLICATION); + break; + } + } + bool PreVisit(Create &) override { AddPrivilege(AuthQuery::Privilege::CREATE); return false; diff --git a/src/query/interpreter.cpp b/src/query/interpreter.cpp index 301c64f63..6e5c1b48c 100644 --- a/src/query/interpreter.cpp +++ b/src/query/interpreter.cpp @@ -323,6 +323,121 @@ Callback HandleAuthQuery(AuthQuery *auth_query, AuthQueryHandler *auth, } } +Callback HandleReplicationQuery(ReplicationQuery *repl_query, + ReplicationQueryHandler *handler, + const Parameters ¶meters, + DbAccessor *db_accessor) { + Frame frame(0); + SymbolTable symbol_table; + EvaluationContext evaluation_context; + evaluation_context.timestamp = + std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::system_clock::now().time_since_epoch()) + .count(); + evaluation_context.parameters = parameters; + ExpressionEvaluator evaluator(&frame, symbol_table, evaluation_context, + db_accessor, storage::View::OLD); + + Callback callback; + switch (repl_query->action_) { + case ReplicationQuery::Action::SET_REPLICATION_MODE: { + callback.fn = [handler, mode = repl_query->mode_] { + if (!handler->SetReplicationMode(mode)) { + throw QueryRuntimeException( + "Couldn't set the desired replication mode."); + } + return std::vector<std::vector<TypedValue>>(); + }; + return callback; + } + case ReplicationQuery::Action::SHOW_REPLICATION_MODE: { + callback.header = {"replication mode"}; + callback.fn = [handler] { + auto mode = handler->ShowReplicationMode(); + switch (mode) { + case ReplicationQuery::ReplicationMode::MAIN: { + return std::vector<std::vector<TypedValue>>{{TypedValue("main")}}; + } + case ReplicationQuery::ReplicationMode::REPLICA: { + return std::vector<std::vector<TypedValue>>{ + {TypedValue("replica")}}; + } + } + }; + return callback; + } + case ReplicationQuery::Action::CREATE_REPLICA: { + const auto &name = repl_query->replica_name_; + const auto &sync_mode = repl_query->sync_mode_; + auto hostname = + EvaluateOptionalExpression(repl_query->hostname_, &evaluator); + auto timeout = + EvaluateOptionalExpression(repl_query->timeout_, &evaluator); + std::optional<double> opt_timeout; + if (timeout.IsDouble()) { + opt_timeout = timeout.ValueDouble(); + } else if (timeout.IsInt()) { + opt_timeout = static_cast<double>(timeout.ValueInt()); + } + callback.fn = [handler, name, hostname, sync_mode, opt_timeout] { + CHECK(hostname.IsString()); + if (!handler->CreateReplica(name, std::string(hostname.ValueString()), + sync_mode, opt_timeout)) { + throw QueryRuntimeException( + "Couldn't create the desired replica '{}'.", name); + } + return std::vector<std::vector<TypedValue>>(); + }; + return callback; + } + case ReplicationQuery::Action::DROP_REPLICA: { + const auto &name = repl_query->replica_name_; + callback.fn = [handler, name] { + if (!handler->DropReplica(name)) { + throw QueryRuntimeException("Couldn't drop the replica '{}'.", name); + } + return std::vector<std::vector<TypedValue>>(); + }; + return callback; + } + case ReplicationQuery::Action::SHOW_REPLICAS: { + callback.header = {"name", "hostname", "sync_mode", "timeout"}; + callback.fn = [handler, replica_nfields = callback.header.size()] { + const auto &replicas = handler->ShowReplicas(); + auto typed_replicas = std::vector<std::vector<TypedValue>>{}; + typed_replicas.reserve(replicas.size()); + for (auto &replica : replicas) { + std::vector<TypedValue> typed_replica; + typed_replica.reserve(replica_nfields); + + typed_replica.emplace_back(TypedValue(replica.name)); + typed_replica.emplace_back(TypedValue(replica.hostname)); + switch (replica.sync_mode) { + case ReplicationQuery::SyncMode::SYNC: + typed_replica.emplace_back(TypedValue("sync")); + break; + case ReplicationQuery::SyncMode::ASYNC: + typed_replica.emplace_back(TypedValue("async")); + break; + } + typed_replica.emplace_back( + TypedValue(static_cast<int64_t>(replica.sync_mode))); + if (replica.timeout) { + typed_replica.emplace_back(TypedValue(*replica.timeout)); + } else { + typed_replica.emplace_back(TypedValue()); + } + + typed_replicas.emplace_back(std::move(typed_replica)); + } + return typed_replicas; + }; + return callback; + } + return callback; + } +} + Interpreter::Interpreter(InterpreterContext *interpreter_context) : interpreter_context_(interpreter_context) { CHECK(interpreter_context_) << "Interpreter context must not be NULL"; @@ -896,6 +1011,50 @@ PreparedQuery PrepareAuthQuery( }}; } +PreparedQuery PrepareReplicationQuery( + ParsedQuery parsed_query, bool in_explicit_transaction, + std::map<std::string, TypedValue> *summary, + InterpreterContext *interpreter_context, DbAccessor *dba, + utils::MonotonicBufferResource *execution_memory) { + if (in_explicit_transaction) { + throw ReplicationModificationInMulticommandTxException(); + } + + auto *replication_query = + utils::Downcast<ReplicationQuery>(parsed_query.query); + auto callback = + HandleReplicationQuery(replication_query, interpreter_context->repl, + parsed_query.parameters, dba); + + SymbolTable symbol_table; + std::vector<Symbol> output_symbols; + for (const auto &column : callback.header) { + output_symbols.emplace_back(symbol_table.CreateSymbol(column, "false")); + } + + auto plan = + std::make_shared<CachedPlan>(std::make_unique<SingleNodeLogicalPlan>( + std::make_unique<plan::OutputTable>( + output_symbols, + [fn = callback.fn](Frame *, ExecutionContext *) { return fn(); }), + 0.0, AstStorage{}, symbol_table)); + auto pull_plan = + std::make_shared<PullPlan>(plan, parsed_query.parameters, false, dba, + interpreter_context, execution_memory); + return PreparedQuery{ + callback.header, std::move(parsed_query.required_privileges), + [pull_plan = std::move(pull_plan), callback = std::move(callback), + output_symbols = std::move(output_symbols), + summary](AnyStream *stream, + std::optional<int> n) -> std::optional<QueryHandlerResult> { + if (pull_plan->Pull(stream, n, output_symbols, summary)) { + return callback.should_abort_query ? QueryHandlerResult::ABORT + : QueryHandlerResult::COMMIT; + } + return std::nullopt; + }}; +} + PreparedQuery PrepareInfoQuery( ParsedQuery parsed_query, bool in_explicit_transaction, std::map<std::string, TypedValue> *summary, @@ -1276,6 +1435,11 @@ Interpreter::PrepareResult Interpreter::Prepare( std::move(parsed_query), in_explicit_transaction_, &query_execution->summary, interpreter_context_, &query_execution->execution_memory); + } else if (utils::Downcast<ReplicationQuery>(parsed_query.query)) { + prepared_query = PrepareReplicationQuery( + std::move(parsed_query), in_explicit_transaction_, + &query_execution->summary, interpreter_context_, + &*execution_db_accessor_, &query_execution->execution_memory); } else { LOG(FATAL) << "Should not get here -- unknown query type!"; } diff --git a/src/query/interpreter.hpp b/src/query/interpreter.hpp index cd0ed0d56..220ba2c57 100644 --- a/src/query/interpreter.hpp +++ b/src/query/interpreter.hpp @@ -98,6 +98,45 @@ class AuthQueryHandler { enum class QueryHandlerResult { COMMIT, ABORT, NOTHING }; +class ReplicationQueryHandler { + public: + ReplicationQueryHandler() = default; + virtual ~ReplicationQueryHandler() = default; + + ReplicationQueryHandler(const ReplicationQueryHandler &) = delete; + ReplicationQueryHandler(ReplicationQueryHandler &&) = delete; + ReplicationQueryHandler &operator=(const ReplicationQueryHandler &) = delete; + ReplicationQueryHandler &operator=(ReplicationQueryHandler &&) = delete; + + struct Replica { + std::string name; + std::string hostname; + ReplicationQuery::SyncMode sync_mode; + std::optional<double> timeout; + }; + + /// @throw QueryRuntimeException if an error ocurred. + virtual bool SetReplicationMode( + ReplicationQuery::ReplicationMode replication_mode) = 0; + + /// @throw QueryRuntimeException if an error ocurred. + virtual ReplicationQuery::ReplicationMode ShowReplicationMode() const = 0; + + /// Return false if the replica already exists. + /// @throw QueryRuntimeException if an error ocurred. + virtual bool CreateReplica(const std::string &name, + const std::string &hostname, + ReplicationQuery::SyncMode sync_mode, + std::optional<double> timeout) = 0; + + /// Return false if the replica doesn't exist. + /// @throw QueryRuntimeException if an error ocurred. + virtual bool DropReplica(const std::string &replica_name) = 0; + + /// @throw QueryRuntimeException if an error ocurred. + virtual std::vector<Replica> ShowReplicas() const = 0; +}; + /** * A container for data related to the preparation of a query. */ @@ -205,6 +244,7 @@ struct InterpreterContext { double execution_timeout_sec{180.0}; AuthQueryHandler *auth{nullptr}; + ReplicationQueryHandler *repl{nullptr}; utils::SkipList<QueryCacheEntry> ast_cache; utils::SkipList<PlanCacheEntry> plan_cache; diff --git a/tests/unit/cypher_main_visitor.cpp b/tests/unit/cypher_main_visitor.cpp index 469e11b9f..10800a998 100644 --- a/tests/unit/cypher_main_visitor.cpp +++ b/tests/unit/cypher_main_visitor.cpp @@ -2453,6 +2453,94 @@ TEST_P(CypherMainVisitorTest, ShowUsersForRole) { SyntaxException); } +void check_replication_query(Base *ast_generator, const ReplicationQuery *query, + const std::string name, + const std::optional<TypedValue> hostname, + const ReplicationQuery::SyncMode sync_mode, + const std::optional<TypedValue> timeout) { + EXPECT_EQ(query->replica_name_, name); + EXPECT_EQ(query->sync_mode_, sync_mode); + ASSERT_EQ(static_cast<bool>(query->hostname_), static_cast<bool>(hostname)); + if (hostname) { + ast_generator->CheckLiteral(query->hostname_, *hostname); + } + ASSERT_EQ(static_cast<bool>(query->timeout_), static_cast<bool>(timeout)); + if (timeout) { + ast_generator->CheckLiteral(query->timeout_, *timeout); + } +} + +TEST_P(CypherMainVisitorTest, TestShowReplicationMode) { + auto &ast_generator = *GetParam(); + std::string raw_query = "SHOW REPLICATION MODE"; + auto *parsed_query = + dynamic_cast<ReplicationQuery *>(ast_generator.ParseQuery(raw_query)); + EXPECT_EQ(parsed_query->action_, + ReplicationQuery::Action::SHOW_REPLICATION_MODE); +} + +TEST_P(CypherMainVisitorTest, TestShowReplicasQuery) { + auto &ast_generator = *GetParam(); + std::string raw_query = "SHOW REPLICAS"; + auto *parsed_query = + dynamic_cast<ReplicationQuery *>(ast_generator.ParseQuery(raw_query)); + EXPECT_EQ(parsed_query->action_, ReplicationQuery::Action::SHOW_REPLICAS); +} + +TEST_P(CypherMainVisitorTest, TestSetReplicationMode) { + auto &ast_generator = *GetParam(); + std::string missing_mode_query = "SET REPLICATION MODE"; + ASSERT_THROW(ast_generator.ParseQuery(missing_mode_query), SyntaxException); + + std::string bad_mode_query = "SET REPLICATION MODE TO BUTTERY"; + ASSERT_THROW(ast_generator.ParseQuery(bad_mode_query), SyntaxException); + + std::string full_query = "SET REPLICATION MODE TO MAIN"; + auto *parsed_full_query = + dynamic_cast<ReplicationQuery *>(ast_generator.ParseQuery(full_query)); + EXPECT_EQ(parsed_full_query->action_, + ReplicationQuery::Action::SET_REPLICATION_MODE); + EXPECT_EQ(parsed_full_query->mode_, ReplicationQuery::ReplicationMode::MAIN); +} + +TEST_P(CypherMainVisitorTest, TestCreateReplicationQuery) { + auto &ast_generator = *GetParam(); + + std::string faulty_query = "CREATE REPLICA WITH TIMEOUT TO"; + ASSERT_THROW(ast_generator.ParseQuery(faulty_query), SyntaxException); + + std::string no_timeout_query = + R"(CREATE REPLICA replica1 SYNC TO "127.0.0.1")"; + auto *no_timeout_query_parsed = dynamic_cast<ReplicationQuery *>( + ast_generator.ParseQuery(no_timeout_query)); + ASSERT_TRUE(no_timeout_query_parsed); + check_replication_query(&ast_generator, no_timeout_query_parsed, "replica1", + TypedValue("127.0.0.1"), + ReplicationQuery::SyncMode::SYNC, {}); + + std::string full_query = + R"(CREATE REPLICA replica2 SYNC WITH TIMEOUT 0.5 TO "1.1.1.1")"; + auto *full_query_parsed = + dynamic_cast<ReplicationQuery *>(ast_generator.ParseQuery(full_query)); + ASSERT_TRUE(full_query_parsed); + check_replication_query(&ast_generator, full_query_parsed, "replica2", + TypedValue("1.1.1.1"), + ReplicationQuery::SyncMode::SYNC, TypedValue(0.5)); +} + +TEST_P(CypherMainVisitorTest, TestDeleteReplica) { + auto &ast_generator = *GetParam(); + + std::string missing_name_query = "DROP REPLICA"; + ASSERT_THROW(ast_generator.ParseQuery(missing_name_query), SyntaxException); + + std::string correct_query = "DROP REPLICA replica1"; + auto *correct_query_parsed = + dynamic_cast<ReplicationQuery *>(ast_generator.ParseQuery(correct_query)); + ASSERT_TRUE(correct_query_parsed); + EXPECT_EQ(correct_query_parsed->replica_name_, "replica1"); +} + TEST_P(CypherMainVisitorTest, TestExplainRegularQuery) { auto &ast_generator = *GetParam(); EXPECT_TRUE(dynamic_cast<ExplainQuery *>(