diff --git a/src/glue/auth_checker.cpp b/src/glue/auth_checker.cpp index 1c6b9ab9a..981ab8cca 100644 --- a/src/glue/auth_checker.cpp +++ b/src/glue/auth_checker.cpp @@ -94,20 +94,24 @@ std::unique_ptr<memgraph::query::FineGrainedAuthChecker> AuthChecker::GetFineGra return {}; } try { - auto locked_auth = auth_->Lock(); - if (username != user_.username()) { - auto maybe_user = locked_auth->GetUser(username); + auto user = user_.Lock(); + if (username != user->username()) { + auto maybe_user = auth_->ReadLock()->GetUser(username); if (!maybe_user) { throw memgraph::query::QueryRuntimeException("User '{}' doesn't exist .", username); } - user_ = std::move(*maybe_user); + *user = std::move(*maybe_user); } - return std::make_unique<memgraph::glue::FineGrainedAuthChecker>(user_, dba); + return std::make_unique<memgraph::glue::FineGrainedAuthChecker>(*user, dba); } catch (const memgraph::auth::AuthException &e) { throw memgraph::query::QueryRuntimeException(e.what()); } } + +void AuthChecker::ClearCache() const { + user_.WithLock([](auto &user) mutable { user = {}; }); +} #endif bool AuthChecker::IsUserAuthorized(const memgraph::auth::User &user, diff --git a/src/glue/auth_checker.hpp b/src/glue/auth_checker.hpp index 75c35eacf..e926c120b 100644 --- a/src/glue/auth_checker.hpp +++ b/src/glue/auth_checker.hpp @@ -16,6 +16,7 @@ #include "query/auth_checker.hpp" #include "query/db_accessor.hpp" #include "query/frontend/ast/ast.hpp" +#include "utils/spin_lock.hpp" namespace memgraph::glue { @@ -32,6 +33,8 @@ class AuthChecker : public query::AuthChecker { std::unique_ptr<memgraph::query::FineGrainedAuthChecker> GetFineGrainedAuthChecker( const std::string &username, const memgraph::query::DbAccessor *dba) const override; + void ClearCache() const override; + #endif [[nodiscard]] static bool IsUserAuthorized(const memgraph::auth::User &user, const std::vector<memgraph::query::AuthQuery::Privilege> &privileges, @@ -39,7 +42,7 @@ class AuthChecker : public query::AuthChecker { private: memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth_; - mutable auth::User user_; + mutable memgraph::utils::Synchronized<auth::User, memgraph::utils::SpinLock> user_; // cached user }; #ifdef MG_ENTERPRISE class FineGrainedAuthChecker : public query::FineGrainedAuthChecker { diff --git a/src/query/auth_checker.hpp b/src/query/auth_checker.hpp index cb1be8985..f64c16b1e 100644 --- a/src/query/auth_checker.hpp +++ b/src/query/auth_checker.hpp @@ -30,6 +30,8 @@ class AuthChecker { #ifdef MG_ENTERPRISE [[nodiscard]] virtual std::unique_ptr<FineGrainedAuthChecker> GetFineGrainedAuthChecker( const std::string &username, const memgraph::query::DbAccessor *db_accessor) const = 0; + + virtual void ClearCache() const = 0; #endif }; #ifdef MG_ENTERPRISE @@ -103,6 +105,8 @@ class AllowEverythingAuthChecker final : public query::AuthChecker { const query::DbAccessor * /*dba*/) const override { return std::make_unique<AllowEverythingFineGrainedAuthChecker>(); } + + void ClearCache() const override {} #endif }; // namespace memgraph::query diff --git a/src/query/cypher_query_interpreter.cpp b/src/query/cypher_query_interpreter.cpp index 57f0f79e8..8759333bc 100644 --- a/src/query/cypher_query_interpreter.cpp +++ b/src/query/cypher_query_interpreter.cpp @@ -76,7 +76,7 @@ ParsedQuery ParseQuery(const std::string &query_string, const std::map<std::stri // Convert the ANTLR4 parse tree into an AST. AstStorage ast_storage; - frontend::ParsingContext context{true}; + frontend::ParsingContext context{.is_query_cached = true}; frontend::CypherMainVisitor visitor(context, &ast_storage); visitor.visit(parser->tree()); diff --git a/src/query/interpret/eval.cpp b/src/query/interpret/eval.cpp index c60972234..8bd308420 100644 --- a/src/query/interpret/eval.cpp +++ b/src/query/interpret/eval.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -22,9 +22,10 @@ int64_t EvaluateInt(ExpressionEvaluator *evaluator, Expression *expr, const std: } } -std::optional<size_t> EvaluateMemoryLimit(ExpressionEvaluator *eval, Expression *memory_limit, size_t memory_scale) { +std::optional<size_t> EvaluateMemoryLimit(ExpressionVisitor<TypedValue> &eval, Expression *memory_limit, + size_t memory_scale) { if (!memory_limit) return std::nullopt; - auto limit_value = memory_limit->Accept(*eval); + auto limit_value = memory_limit->Accept(eval); if (!limit_value.IsInt() || limit_value.ValueInt() <= 0) throw QueryRuntimeException("Memory limit must be a non-negative integer."); size_t limit = limit_value.ValueInt(); diff --git a/src/query/interpret/eval.hpp b/src/query/interpret/eval.hpp index b6c0e4675..1c78d6c9b 100644 --- a/src/query/interpret/eval.hpp +++ b/src/query/interpret/eval.hpp @@ -99,12 +99,79 @@ class ReferenceExpressionEvaluator : public ExpressionVisitor<TypedValue *> { UNSUCCESSFUL_VISIT(RegexMatch); UNSUCCESSFUL_VISIT(Exists); +#undef UNSUCCESSFUL_VISIT + private: Frame *frame_; const SymbolTable *symbol_table_; const EvaluationContext *ctx_; }; +class PrimitiveLiteralExpressionEvaluator : public ExpressionVisitor<TypedValue> { + public: + explicit PrimitiveLiteralExpressionEvaluator(EvaluationContext const &ctx) : ctx_(&ctx) {} + using ExpressionVisitor<TypedValue>::Visit; + TypedValue Visit(PrimitiveLiteral &literal) override { + // TODO: no need to evaluate constants, we can write it to frame in one + // of the previous phases. + return TypedValue(literal.value_, ctx_->memory); + } + TypedValue Visit(ParameterLookup ¶m_lookup) override { + return TypedValue(ctx_->parameters.AtTokenPosition(param_lookup.token_position_), ctx_->memory); + } + +#define INVALID_VISIT(expr_name) \ + TypedValue Visit(expr_name & /*expr*/) override { \ + DLOG_FATAL("Invalid expression type visited with PrimitiveLiteralExpressionEvaluator."); \ + return {}; \ + } + + INVALID_VISIT(NamedExpression) + INVALID_VISIT(OrOperator) + INVALID_VISIT(XorOperator) + INVALID_VISIT(AndOperator) + INVALID_VISIT(NotOperator) + INVALID_VISIT(AdditionOperator) + INVALID_VISIT(SubtractionOperator) + INVALID_VISIT(MultiplicationOperator) + INVALID_VISIT(DivisionOperator) + INVALID_VISIT(ModOperator) + INVALID_VISIT(NotEqualOperator) + INVALID_VISIT(EqualOperator) + INVALID_VISIT(LessOperator) + INVALID_VISIT(GreaterOperator) + INVALID_VISIT(LessEqualOperator) + INVALID_VISIT(GreaterEqualOperator) + INVALID_VISIT(InListOperator) + INVALID_VISIT(SubscriptOperator) + INVALID_VISIT(ListSlicingOperator) + INVALID_VISIT(IfOperator) + INVALID_VISIT(UnaryPlusOperator) + INVALID_VISIT(UnaryMinusOperator) + INVALID_VISIT(IsNullOperator) + INVALID_VISIT(ListLiteral) + INVALID_VISIT(MapLiteral) + INVALID_VISIT(MapProjectionLiteral) + INVALID_VISIT(PropertyLookup) + INVALID_VISIT(AllPropertiesLookup) + INVALID_VISIT(LabelsTest) + INVALID_VISIT(Aggregation) + INVALID_VISIT(Function) + INVALID_VISIT(Reduce) + INVALID_VISIT(Coalesce) + INVALID_VISIT(Extract) + INVALID_VISIT(All) + INVALID_VISIT(Single) + INVALID_VISIT(Any) + INVALID_VISIT(None) + INVALID_VISIT(Identifier) + INVALID_VISIT(RegexMatch) + INVALID_VISIT(Exists) + +#undef INVALID_VISIT + private: + EvaluationContext const *ctx_; +}; class ExpressionEvaluator : public ExpressionVisitor<TypedValue> { public: ExpressionEvaluator(Frame *frame, const SymbolTable &symbol_table, const EvaluationContext &ctx, DbAccessor *dba, @@ -1046,6 +1113,7 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> { /// @throw QueryRuntimeException if expression doesn't evaluate to an int. int64_t EvaluateInt(ExpressionEvaluator *evaluator, Expression *expr, const std::string &what); -std::optional<size_t> EvaluateMemoryLimit(ExpressionEvaluator *eval, Expression *memory_limit, size_t memory_scale); +std::optional<size_t> EvaluateMemoryLimit(ExpressionVisitor<TypedValue> &eval, Expression *memory_limit, + size_t memory_scale); } // namespace memgraph::query diff --git a/src/query/interpreter.cpp b/src/query/interpreter.cpp index 487ae2eee..845259607 100644 --- a/src/query/interpreter.cpp +++ b/src/query/interpreter.cpp @@ -162,12 +162,12 @@ struct Callback { bool should_abort_query{false}; }; -TypedValue EvaluateOptionalExpression(Expression *expression, ExpressionEvaluator *eval) { - return expression ? expression->Accept(*eval) : TypedValue(); +TypedValue EvaluateOptionalExpression(Expression *expression, ExpressionVisitor<TypedValue> &eval) { + return expression ? expression->Accept(eval) : TypedValue(); } template <typename TResult> -std::optional<TResult> GetOptionalValue(query::Expression *expression, ExpressionEvaluator &evaluator) { +std::optional<TResult> GetOptionalValue(query::Expression *expression, ExpressionVisitor<TypedValue> &evaluator) { if (expression != nullptr) { auto int_value = expression->Accept(evaluator); MG_ASSERT(int_value.IsNull() || int_value.IsInt()); @@ -178,7 +178,8 @@ std::optional<TResult> GetOptionalValue(query::Expression *expression, Expressio return {}; }; -std::optional<std::string> GetOptionalStringValue(query::Expression *expression, ExpressionEvaluator &evaluator) { +std::optional<std::string> GetOptionalStringValue(query::Expression *expression, + ExpressionVisitor<TypedValue> &evaluator) { if (expression != nullptr) { auto value = expression->Accept(evaluator); MG_ASSERT(value.IsNull() || value.IsString()); @@ -356,24 +357,17 @@ class ReplQueryHandler final : public query::ReplicationQueryHandler { /// returns false if the replication role can't be set /// @throw QueryRuntimeException if an error ocurred. -Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_context, const Parameters ¶meters, - DbAccessor *db_accessor) { +Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_context, const Parameters ¶meters) { AuthQueryHandler *auth = interpreter_context->auth; #ifdef MG_ENTERPRISE auto &sc_handler = memgraph::dbms::SessionContextHandler::ExtractSCH(interpreter_context); #endif - // Empty frame for evaluation of password expression. This is OK since - // password should be either null or string literal and it's evaluation - // should not depend on frame. - Frame frame(0); - SymbolTable symbol_table; - EvaluationContext evaluation_context; // TODO: MemoryResource for EvaluationContext, it should probably be passed as // the argument to Callback. + EvaluationContext evaluation_context; evaluation_context.timestamp = QueryTimestamp(); - evaluation_context.parameters = parameters; - ExpressionEvaluator evaluator(&frame, symbol_table, evaluation_context, db_accessor, storage::View::OLD); + auto evaluator = PrimitiveLiteralExpressionEvaluator{evaluation_context}; std::string username = auth_query->user_; std::string rolename = auth_query->role_; @@ -386,7 +380,7 @@ Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_ std::vector<std::unordered_map<AuthQuery::FineGrainedPrivilege, std::vector<std::string>>> edge_type_privileges = auth_query->edge_type_privileges_; #endif - auto password = EvaluateOptionalExpression(auth_query->password_, &evaluator); + auto password = EvaluateOptionalExpression(auth_query->password_, evaluator); Callback callback; @@ -646,21 +640,18 @@ Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_ } // namespace Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters ¶meters, - InterpreterContext *interpreter_context, DbAccessor *db_accessor, - std::vector<Notification> *notifications) { - Frame frame(0); - SymbolTable symbol_table; - EvaluationContext evaluation_context; + InterpreterContext *interpreter_context, std::vector<Notification> *notifications) { // TODO: MemoryResource for EvaluationContext, it should probably be passed as // the argument to Callback. + EvaluationContext evaluation_context; evaluation_context.timestamp = QueryTimestamp(); evaluation_context.parameters = parameters; - ExpressionEvaluator evaluator(&frame, symbol_table, evaluation_context, db_accessor, storage::View::OLD); + auto evaluator = PrimitiveLiteralExpressionEvaluator{evaluation_context}; Callback callback; switch (repl_query->action_) { case ReplicationQuery::Action::SET_REPLICATION_ROLE: { - auto port = EvaluateOptionalExpression(repl_query->port_, &evaluator); + auto port = EvaluateOptionalExpression(repl_query->port_, evaluator); std::optional<int64_t> maybe_port; if (port.IsInt()) { maybe_port = port.ValueInt(); @@ -779,7 +770,7 @@ std::optional<std::string> StringPointerToOptional(const std::string *str) { return str == nullptr ? std::nullopt : std::make_optional(*str); } -stream::CommonStreamInfo GetCommonStreamInfo(StreamQuery *stream_query, ExpressionEvaluator &evaluator) { +stream::CommonStreamInfo GetCommonStreamInfo(StreamQuery *stream_query, ExpressionVisitor<TypedValue> &evaluator) { return { .batch_interval = GetOptionalValue<std::chrono::milliseconds>(stream_query->batch_interval_, evaluator) .value_or(stream::kDefaultBatchInterval), @@ -787,7 +778,7 @@ stream::CommonStreamInfo GetCommonStreamInfo(StreamQuery *stream_query, Expressi .transformation_name = stream_query->transform_name_}; } -std::vector<std::string> EvaluateTopicNames(ExpressionEvaluator &evaluator, +std::vector<std::string> EvaluateTopicNames(ExpressionVisitor<TypedValue> &evaluator, std::variant<Expression *, std::vector<std::string>> topic_variant) { return std::visit(utils::Overloaded{[&](Expression *expression) { auto topic_names = expression->Accept(evaluator); @@ -798,7 +789,7 @@ std::vector<std::string> EvaluateTopicNames(ExpressionEvaluator &evaluator, std::move(topic_variant)); } -Callback::CallbackFunction GetKafkaCreateCallback(StreamQuery *stream_query, ExpressionEvaluator &evaluator, +Callback::CallbackFunction GetKafkaCreateCallback(StreamQuery *stream_query, ExpressionVisitor<TypedValue> &evaluator, InterpreterContext *interpreter_context, const std::string *username) { static constexpr std::string_view kDefaultConsumerGroup = "mg_consumer"; @@ -849,7 +840,7 @@ Callback::CallbackFunction GetKafkaCreateCallback(StreamQuery *stream_query, Exp }; } -Callback::CallbackFunction GetPulsarCreateCallback(StreamQuery *stream_query, ExpressionEvaluator &evaluator, +Callback::CallbackFunction GetPulsarCreateCallback(StreamQuery *stream_query, ExpressionVisitor<TypedValue> &evaluator, InterpreterContext *interpreter_context, const std::string *username) { auto service_url = GetOptionalStringValue(stream_query->service_url_, evaluator); @@ -875,16 +866,14 @@ Callback::CallbackFunction GetPulsarCreateCallback(StreamQuery *stream_query, Ex } Callback HandleStreamQuery(StreamQuery *stream_query, const Parameters ¶meters, - InterpreterContext *interpreter_context, DbAccessor *db_accessor, - const std::string *username, std::vector<Notification> *notifications) { - Frame frame(0); - SymbolTable symbol_table; - EvaluationContext evaluation_context; + InterpreterContext *interpreter_context, const std::string *username, + std::vector<Notification> *notifications) { // TODO: MemoryResource for EvaluationContext, it should probably be passed as // the argument to Callback. + EvaluationContext evaluation_context; evaluation_context.timestamp = QueryTimestamp(); evaluation_context.parameters = parameters; - ExpressionEvaluator evaluator(&frame, symbol_table, evaluation_context, db_accessor, storage::View::OLD); + PrimitiveLiteralExpressionEvaluator evaluator{evaluation_context}; Callback callback; switch (stream_query->action_) { @@ -1040,27 +1029,23 @@ Callback HandleConfigQuery() { return callback; } -Callback HandleSettingQuery(SettingQuery *setting_query, const Parameters ¶meters, DbAccessor *db_accessor) { - Frame frame(0); - SymbolTable symbol_table; - EvaluationContext evaluation_context; +Callback HandleSettingQuery(SettingQuery *setting_query, const Parameters ¶meters) { // TODO: MemoryResource for EvaluationContext, it should probably be passed as // the argument to Callback. - evaluation_context.timestamp = - std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()) - .count(); + EvaluationContext evaluation_context; + evaluation_context.timestamp = QueryTimestamp(); evaluation_context.parameters = parameters; - ExpressionEvaluator evaluator(&frame, symbol_table, evaluation_context, db_accessor, storage::View::OLD); + auto evaluator = PrimitiveLiteralExpressionEvaluator{evaluation_context}; Callback callback; switch (setting_query->action_) { case SettingQuery::Action::SET_SETTING: { - const auto setting_name = EvaluateOptionalExpression(setting_query->setting_name_, &evaluator); + const auto setting_name = EvaluateOptionalExpression(setting_query->setting_name_, evaluator); if (!setting_name.IsString()) { throw utils::BasicException("Setting name should be a string literal"); } - const auto setting_value = EvaluateOptionalExpression(setting_query->setting_value_, &evaluator); + const auto setting_value = EvaluateOptionalExpression(setting_query->setting_value_, evaluator); if (!setting_value.IsString()) { throw utils::BasicException("Setting value should be a string literal"); } @@ -1075,7 +1060,7 @@ Callback HandleSettingQuery(SettingQuery *setting_query, const Parameters ¶m return callback; } case SettingQuery::Action::SHOW_SETTING: { - const auto setting_name = EvaluateOptionalExpression(setting_query->setting_name_, &evaluator); + const auto setting_name = EvaluateOptionalExpression(setting_query->setting_name_, evaluator); if (!setting_name.IsString()) { throw utils::BasicException("Setting name should be a string literal"); } @@ -1540,14 +1525,12 @@ PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map<std::string, FrameChangeCollector *frame_change_collector = nullptr) { auto *cypher_query = utils::Downcast<CypherQuery>(parsed_query.query); - Frame frame(0); - SymbolTable symbol_table; EvaluationContext evaluation_context; evaluation_context.timestamp = QueryTimestamp(); evaluation_context.parameters = parsed_query.parameters; + auto evaluator = PrimitiveLiteralExpressionEvaluator{evaluation_context}; - ExpressionEvaluator evaluator(&frame, symbol_table, evaluation_context, dba, storage::View::OLD); - const auto memory_limit = EvaluateMemoryLimit(&evaluator, cypher_query->memory_limit_, cypher_query->memory_scale_); + const auto memory_limit = EvaluateMemoryLimit(evaluator, cypher_query->memory_limit_, cypher_query->memory_scale_); if (memory_limit) { spdlog::info("Running query with memory limit of {}", utils::GetReadableSize(*memory_limit)); } @@ -1708,13 +1691,11 @@ PreparedQuery PrepareProfileQuery(ParsedQuery parsed_query, bool in_explicit_tra !contains_csv && !IsCallBatchedProcedureQuery(clauses) && !IsAllShortestPathsQuery(clauses); MG_ASSERT(cypher_query, "Cypher grammar should not allow other queries in PROFILE"); - Frame frame(0); - SymbolTable symbol_table; EvaluationContext evaluation_context; evaluation_context.timestamp = QueryTimestamp(); evaluation_context.parameters = parsed_inner_query.parameters; - ExpressionEvaluator evaluator(&frame, symbol_table, evaluation_context, dba, storage::View::OLD); - const auto memory_limit = EvaluateMemoryLimit(&evaluator, cypher_query->memory_limit_, cypher_query->memory_scale_); + auto evaluator = PrimitiveLiteralExpressionEvaluator{evaluation_context}; + const auto memory_limit = EvaluateMemoryLimit(evaluator, cypher_query->memory_limit_, cypher_query->memory_scale_); auto cypher_query_plan = CypherQueryToPlan( parsed_inner_query.stripped_query.hash(), std::move(parsed_inner_query.ast_storage), cypher_query, @@ -2138,38 +2119,31 @@ PreparedQuery PrepareIndexQuery(ParsedQuery parsed_query, bool in_explicit_trans } PreparedQuery PrepareAuthQuery(ParsedQuery parsed_query, bool in_explicit_transaction, - std::map<std::string, TypedValue> *summary, InterpreterContext *interpreter_context, - DbAccessor *dba, utils::MemoryResource *execution_memory, const std::string *username, - std::atomic<TransactionStatus> *transaction_status, - std::shared_ptr<utils::AsyncTimer> tx_timer) { + InterpreterContext *interpreter_context) { if (in_explicit_transaction) { throw UserModificationInMulticommandTxException(); } auto *auth_query = utils::Downcast<AuthQuery>(parsed_query.query); - auto callback = HandleAuthQuery(auth_query, interpreter_context, parsed_query.parameters, dba); + auto callback = HandleAuthQuery(auth_query, interpreter_context, parsed_query.parameters); - SymbolTable symbol_table; - std::vector<Symbol> output_symbols; - for (const auto &column : callback.header) { - output_symbols.emplace_back(symbol_table.CreateSymbol(column, "false")); - } - - auto plan = std::make_shared<CachedPlan>(std::make_unique<SingleNodeLogicalPlan>( - std::make_unique<plan::OutputTable>(output_symbols, - [fn = callback.fn](Frame *, ExecutionContext *) { return fn(); }), - 0.0, AstStorage{}, symbol_table)); - - auto pull_plan = - std::make_shared<PullPlan>(plan, parsed_query.parameters, false, dba, interpreter_context, execution_memory, - StringPointerToOptional(username), transaction_status, std::move(tx_timer)); return PreparedQuery{ - callback.header, std::move(parsed_query.required_privileges), - [pull_plan = std::move(pull_plan), callback = std::move(callback), output_symbols = std::move(output_symbols), - summary](AnyStream *stream, std::optional<int> n) -> std::optional<QueryHandlerResult> { - if (pull_plan->Pull(stream, n, output_symbols, summary)) { - return callback.should_abort_query ? QueryHandlerResult::ABORT : QueryHandlerResult::COMMIT; + std::move(callback.header), std::move(parsed_query.required_privileges), + [handler = std::move(callback.fn), pull_plan = std::shared_ptr<PullPlanVector>(nullptr), interpreter_context]( + AnyStream *stream, std::optional<int> n) mutable -> std::optional<QueryHandlerResult> { + if (!pull_plan) { + // Run the specific query + auto results = handler(); + pull_plan = std::make_shared<PullPlanVector>(std::move(results)); +#ifdef MG_ENTERPRISE + // Invalidate auth cache after every type of AuthQuery + interpreter_context->auth_checker->ClearCache(); +#endif + } + + if (pull_plan->Pull(stream, n)) { + return QueryHandlerResult::COMMIT; } return std::nullopt; }, @@ -2177,8 +2151,8 @@ PreparedQuery PrepareAuthQuery(ParsedQuery parsed_query, bool in_explicit_transa } PreparedQuery PrepareReplicationQuery(ParsedQuery parsed_query, bool in_explicit_transaction, - std::vector<Notification> *notifications, InterpreterContext *interpreter_context, - DbAccessor *dba) { + std::vector<Notification> *notifications, + InterpreterContext *interpreter_context) { if (in_explicit_transaction) { throw ReplicationModificationInMulticommandTxException(); } @@ -2189,7 +2163,7 @@ PreparedQuery PrepareReplicationQuery(ParsedQuery parsed_query, bool in_explicit auto *replication_query = utils::Downcast<ReplicationQuery>(parsed_query.query); auto callback = - HandleReplicationQuery(replication_query, parsed_query.parameters, interpreter_context, dba, notifications); + HandleReplicationQuery(replication_query, parsed_query.parameters, interpreter_context, notifications); return PreparedQuery{callback.header, std::move(parsed_query.required_privileges), [callback_fn = std::move(callback.fn), pull_plan = std::shared_ptr<PullPlanVector>{nullptr}]( @@ -2443,8 +2417,6 @@ PreparedQuery PrepareTriggerQuery(ParsedQuery parsed_query, bool in_explicit_tra PreparedQuery PrepareStreamQuery(ParsedQuery parsed_query, bool in_explicit_transaction, std::vector<Notification> *notifications, InterpreterContext *interpreter_context, - DbAccessor *dba, - const std::map<std::string, storage::PropertyValue> & /*user_parameters*/, const std::string *username) { if (in_explicit_transaction) { throw StreamQueryInMulticommandTxException(); @@ -2453,7 +2425,7 @@ PreparedQuery PrepareStreamQuery(ParsedQuery parsed_query, bool in_explicit_tran auto *stream_query = utils::Downcast<StreamQuery>(parsed_query.query); MG_ASSERT(stream_query); auto callback = - HandleStreamQuery(stream_query, parsed_query.parameters, interpreter_context, dba, username, notifications); + HandleStreamQuery(stream_query, parsed_query.parameters, interpreter_context, username, notifications); return PreparedQuery{std::move(callback.header), std::move(parsed_query.required_privileges), [callback_fn = std::move(callback.fn), pull_plan = std::shared_ptr<PullPlanVector>{nullptr}]( @@ -2678,7 +2650,7 @@ PreparedQuery PrepareSettingQuery(ParsedQuery parsed_query, bool in_explicit_tra auto *setting_query = utils::Downcast<SettingQuery>(parsed_query.query); MG_ASSERT(setting_query); - auto callback = HandleSettingQuery(setting_query, parsed_query.parameters, dba); + auto callback = HandleSettingQuery(setting_query, parsed_query.parameters); return PreparedQuery{std::move(callback.header), std::move(parsed_query.required_privileges), [callback_fn = std::move(callback.fn), pull_plan = std::shared_ptr<PullPlanVector>{nullptr}]( @@ -2780,13 +2752,11 @@ std::vector<std::vector<TypedValue>> TransactionQueueQueryHandler::KillTransacti Callback HandleTransactionQueueQuery(TransactionQueueQuery *transaction_query, const std::optional<std::string> &username, const Parameters ¶meters, - InterpreterContext *interpreter_context, DbAccessor *db_accessor) { - Frame frame(0); - SymbolTable symbol_table; + InterpreterContext *interpreter_context) { EvaluationContext evaluation_context; evaluation_context.timestamp = QueryTimestamp(); evaluation_context.parameters = parameters; - ExpressionEvaluator evaluator(&frame, symbol_table, evaluation_context, db_accessor, storage::View::OLD); + auto evaluator = PrimitiveLiteralExpressionEvaluator{evaluation_context}; bool hasTransactionManagementPrivilege = interpreter_context->auth_checker->IsUserAuthorized( username, {query::AuthQuery::Privilege::TRANSACTION_MANAGEMENT}, ""); @@ -2836,7 +2806,7 @@ PreparedQuery PrepareTransactionQueueQuery(ParsedQuery parsed_query, const std:: auto *transaction_queue_query = utils::Downcast<TransactionQueueQuery>(parsed_query.query); MG_ASSERT(transaction_queue_query); auto callback = - HandleTransactionQueueQuery(transaction_queue_query, username, parsed_query.parameters, interpreter_context, dba); + HandleTransactionQueueQuery(transaction_queue_query, username, parsed_query.parameters, interpreter_context); return PreparedQuery{std::move(callback.header), std::move(parsed_query.required_privileges), [callback_fn = std::move(callback.fn), pull_plan = std::shared_ptr<PullPlanVector>{nullptr}]( @@ -3384,11 +3354,7 @@ PreparedQuery PrepareMultiDatabaseQuery(ParsedQuery parsed_query, bool in_explic } PreparedQuery PrepareShowDatabasesQuery(ParsedQuery parsed_query, InterpreterContext *interpreter_context, - const std::string &session_uuid, std::map<std::string, TypedValue> *summary, - DbAccessor *dba, utils::MemoryResource *execution_memory, - const std::optional<std::string> &username, - std::atomic<TransactionStatus> *transaction_status, - std::shared_ptr<utils::AsyncTimer> tx_timer) { + const std::string &session_uuid, const std::optional<std::string> &username) { #ifdef MG_ENTERPRISE if (!license::global_license_checker.IsEnterpriseValidFast()) { throw QueryException("Trying to use enterprise feature without a valid license."); @@ -3453,26 +3419,17 @@ PreparedQuery PrepareShowDatabasesQuery(ParsedQuery parsed_query, InterpreterCon return status; }; - SymbolTable symbol_table; - std::vector<Symbol> output_symbols; - for (const auto &column : callback.header) { - output_symbols.emplace_back(symbol_table.CreateSymbol(column, "false")); - } - - auto plan = std::make_shared<CachedPlan>(std::make_unique<SingleNodeLogicalPlan>( - std::make_unique<plan::OutputTable>(output_symbols, - [fn = callback.fn](Frame *, ExecutionContext *) { return fn(); }), - 0.0, AstStorage{}, symbol_table)); - - auto pull_plan = std::make_shared<PullPlan>(plan, parsed_query.parameters, false, dba, interpreter_context, - execution_memory, username, transaction_status, std::move(tx_timer)); - return PreparedQuery{ - callback.header, std::move(parsed_query.required_privileges), - [pull_plan = std::move(pull_plan), callback = std::move(callback), output_symbols = std::move(output_symbols), - summary](AnyStream *stream, std::optional<int> n) -> std::optional<QueryHandlerResult> { - if (pull_plan->Pull(stream, n, output_symbols, summary)) { - return callback.should_abort_query ? QueryHandlerResult::ABORT : QueryHandlerResult::COMMIT; + std::move(callback.header), std::move(parsed_query.required_privileges), + [handler = std::move(callback.fn), pull_plan = std::shared_ptr<PullPlanVector>(nullptr)]( + AnyStream *stream, std::optional<int> n) mutable -> std::optional<QueryHandlerResult> { + if (!pull_plan) { + auto results = handler(); + pull_plan = std::make_shared<PullPlanVector>(std::move(results)); + } + + if (pull_plan->Pull(stream, n)) { + return QueryHandlerResult::NOTHING; } return std::nullopt; }, @@ -3649,10 +3606,7 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, prepared_query = PrepareAnalyzeGraphQuery(std::move(parsed_query), in_explicit_transaction_, &*execution_db_accessor_, interpreter_context_); } else if (utils::Downcast<AuthQuery>(parsed_query.query)) { - prepared_query = PrepareAuthQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary, - interpreter_context_, &*execution_db_accessor_, - &query_execution->execution_memory_with_exception, username, - &transaction_status_, std::move(current_timer)); + prepared_query = PrepareAuthQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_); } else if (utils::Downcast<InfoQuery>(parsed_query.query)) { prepared_query = PrepareInfoQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary, interpreter_context_, interpreter_context_->db.get(), @@ -3662,9 +3616,8 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, prepared_query = PrepareConstraintQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->notifications, interpreter_context_); } else if (utils::Downcast<ReplicationQuery>(parsed_query.query)) { - prepared_query = - PrepareReplicationQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->notifications, - interpreter_context_, &*execution_db_accessor_); + prepared_query = PrepareReplicationQuery(std::move(parsed_query), in_explicit_transaction_, + &query_execution->notifications, interpreter_context_); } else if (utils::Downcast<LockPathQuery>(parsed_query.query)) { prepared_query = PrepareLockPathQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_); } else if (utils::Downcast<FreeMemoryQuery>(parsed_query.query)) { @@ -3676,9 +3629,8 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, PrepareTriggerQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->notifications, interpreter_context_, &*execution_db_accessor_, params, username); } else if (utils::Downcast<StreamQuery>(parsed_query.query)) { - prepared_query = - PrepareStreamQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->notifications, - interpreter_context_, &*execution_db_accessor_, params, username); + prepared_query = PrepareStreamQuery(std::move(parsed_query), in_explicit_transaction_, + &query_execution->notifications, interpreter_context_, username); } else if (utils::Downcast<IsolationLevelQuery>(parsed_query.query)) { prepared_query = PrepareIsolationLevelQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_, this); @@ -3698,10 +3650,8 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, prepared_query = PrepareMultiDatabaseQuery(std::move(parsed_query), in_explicit_transaction_, in_explicit_db_, interpreter_context_, session_uuid); } else if (utils::Downcast<ShowDatabasesQuery>(parsed_query.query)) { - prepared_query = PrepareShowDatabasesQuery(std::move(parsed_query), interpreter_context_, session_uuid, - &query_execution->summary, &*execution_db_accessor_, - &query_execution->execution_memory_with_exception, username_, - &transaction_status_, std::move(current_timer)); + prepared_query = + PrepareShowDatabasesQuery(std::move(parsed_query), interpreter_context_, session_uuid, username_); } else { LOG_FATAL("Should not get here -- unknown query type!"); } diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index 227acc2df..12f111903 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -4626,7 +4626,7 @@ class CallProcedureCursor : public Cursor { // TODO: This will probably need to be changed when we add support for // generator like procedures which yield a new result on new query calls. auto *memory = self_->memory_resource; - auto memory_limit = EvaluateMemoryLimit(&evaluator, self_->memory_limit_, self_->memory_scale_); + auto memory_limit = EvaluateMemoryLimit(evaluator, self_->memory_limit_, self_->memory_scale_); auto graph = mgp_graph::WritableGraph(*context.db_accessor, graph_view, context); CallCustomProcedure(self_->procedure_name_, *proc, self_->arguments_, graph, &evaluator, memory, memory_limit, result_, call_initializer); diff --git a/tests/unit/dbms_interp.cpp b/tests/unit/dbms_interp.cpp index a71bcb87c..5dbd51150 100644 --- a/tests/unit/dbms_interp.cpp +++ b/tests/unit/dbms_interp.cpp @@ -76,6 +76,8 @@ class TestAuthChecker : public memgraph::query::AuthChecker { const std::string & /*username*/, const memgraph::query::DbAccessor * /*db_accessor*/) const override { return {}; } + + void ClearCache() const override {} }; std::filesystem::path storage_directory{std::filesystem::temp_directory_path() / "MG_test_unit_dbms_interp"}; diff --git a/tests/unit/query_trigger.cpp b/tests/unit/query_trigger.cpp index 31a28445f..4a75391f0 100644 --- a/tests/unit/query_trigger.cpp +++ b/tests/unit/query_trigger.cpp @@ -47,6 +47,7 @@ class MockAuthChecker : public memgraph::query::AuthChecker { MOCK_CONST_METHOD2(GetFineGrainedAuthChecker, std::unique_ptr<memgraph::query::FineGrainedAuthChecker>( const std::string &username, const memgraph::query::DbAccessor *db_accessor)); + MOCK_CONST_METHOD0(ClearCache, void()); #endif }; } // namespace