[master < ] Remove DbAccessor from non-transactional queries ()

* Decouple non-transactional queries from DbAccessor
* Invalidate auth cache after AuthQuery

Co-authored-by: Gareth Lloyd <gareth.lloyd@memgraph.io>
This commit is contained in:
andrejtonev 2023-08-29 11:13:42 +02:00 committed by GitHub
parent 5f509532f2
commit c526ff2a8f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 170 additions and 137 deletions

View File

@ -94,20 +94,24 @@ std::unique_ptr<memgraph::query::FineGrainedAuthChecker> AuthChecker::GetFineGra
return {};
}
try {
auto locked_auth = auth_->Lock();
if (username != user_.username()) {
auto maybe_user = locked_auth->GetUser(username);
auto user = user_.Lock();
if (username != user->username()) {
auto maybe_user = auth_->ReadLock()->GetUser(username);
if (!maybe_user) {
throw memgraph::query::QueryRuntimeException("User '{}' doesn't exist .", username);
}
user_ = std::move(*maybe_user);
*user = std::move(*maybe_user);
}
return std::make_unique<memgraph::glue::FineGrainedAuthChecker>(user_, dba);
return std::make_unique<memgraph::glue::FineGrainedAuthChecker>(*user, dba);
} catch (const memgraph::auth::AuthException &e) {
throw memgraph::query::QueryRuntimeException(e.what());
}
}
void AuthChecker::ClearCache() const {
user_.WithLock([](auto &user) mutable { user = {}; });
}
#endif
bool AuthChecker::IsUserAuthorized(const memgraph::auth::User &user,

View File

@ -16,6 +16,7 @@
#include "query/auth_checker.hpp"
#include "query/db_accessor.hpp"
#include "query/frontend/ast/ast.hpp"
#include "utils/spin_lock.hpp"
namespace memgraph::glue {
@ -32,6 +33,8 @@ class AuthChecker : public query::AuthChecker {
std::unique_ptr<memgraph::query::FineGrainedAuthChecker> GetFineGrainedAuthChecker(
const std::string &username, const memgraph::query::DbAccessor *dba) const override;
void ClearCache() const override;
#endif
[[nodiscard]] static bool IsUserAuthorized(const memgraph::auth::User &user,
const std::vector<memgraph::query::AuthQuery::Privilege> &privileges,
@ -39,7 +42,7 @@ class AuthChecker : public query::AuthChecker {
private:
memgraph::utils::Synchronized<memgraph::auth::Auth, memgraph::utils::WritePrioritizedRWLock> *auth_;
mutable auth::User user_;
mutable memgraph::utils::Synchronized<auth::User, memgraph::utils::SpinLock> user_; // cached user
};
#ifdef MG_ENTERPRISE
class FineGrainedAuthChecker : public query::FineGrainedAuthChecker {

View File

@ -30,6 +30,8 @@ class AuthChecker {
#ifdef MG_ENTERPRISE
[[nodiscard]] virtual std::unique_ptr<FineGrainedAuthChecker> GetFineGrainedAuthChecker(
const std::string &username, const memgraph::query::DbAccessor *db_accessor) const = 0;
virtual void ClearCache() const = 0;
#endif
};
#ifdef MG_ENTERPRISE
@ -103,6 +105,8 @@ class AllowEverythingAuthChecker final : public query::AuthChecker {
const query::DbAccessor * /*dba*/) const override {
return std::make_unique<AllowEverythingFineGrainedAuthChecker>();
}
void ClearCache() const override {}
#endif
}; // namespace memgraph::query

View File

@ -76,7 +76,7 @@ ParsedQuery ParseQuery(const std::string &query_string, const std::map<std::stri
// Convert the ANTLR4 parse tree into an AST.
AstStorage ast_storage;
frontend::ParsingContext context{true};
frontend::ParsingContext context{.is_query_cached = true};
frontend::CypherMainVisitor visitor(context, &ast_storage);
visitor.visit(parser->tree());

View File

@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2023 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -22,9 +22,10 @@ int64_t EvaluateInt(ExpressionEvaluator *evaluator, Expression *expr, const std:
}
}
std::optional<size_t> EvaluateMemoryLimit(ExpressionEvaluator *eval, Expression *memory_limit, size_t memory_scale) {
std::optional<size_t> EvaluateMemoryLimit(ExpressionVisitor<TypedValue> &eval, Expression *memory_limit,
size_t memory_scale) {
if (!memory_limit) return std::nullopt;
auto limit_value = memory_limit->Accept(*eval);
auto limit_value = memory_limit->Accept(eval);
if (!limit_value.IsInt() || limit_value.ValueInt() <= 0)
throw QueryRuntimeException("Memory limit must be a non-negative integer.");
size_t limit = limit_value.ValueInt();

View File

@ -99,12 +99,79 @@ class ReferenceExpressionEvaluator : public ExpressionVisitor<TypedValue *> {
UNSUCCESSFUL_VISIT(RegexMatch);
UNSUCCESSFUL_VISIT(Exists);
#undef UNSUCCESSFUL_VISIT
private:
Frame *frame_;
const SymbolTable *symbol_table_;
const EvaluationContext *ctx_;
};
class PrimitiveLiteralExpressionEvaluator : public ExpressionVisitor<TypedValue> {
public:
explicit PrimitiveLiteralExpressionEvaluator(EvaluationContext const &ctx) : ctx_(&ctx) {}
using ExpressionVisitor<TypedValue>::Visit;
TypedValue Visit(PrimitiveLiteral &literal) override {
// TODO: no need to evaluate constants, we can write it to frame in one
// of the previous phases.
return TypedValue(literal.value_, ctx_->memory);
}
TypedValue Visit(ParameterLookup &param_lookup) override {
return TypedValue(ctx_->parameters.AtTokenPosition(param_lookup.token_position_), ctx_->memory);
}
#define INVALID_VISIT(expr_name) \
TypedValue Visit(expr_name & /*expr*/) override { \
DLOG_FATAL("Invalid expression type visited with PrimitiveLiteralExpressionEvaluator."); \
return {}; \
}
INVALID_VISIT(NamedExpression)
INVALID_VISIT(OrOperator)
INVALID_VISIT(XorOperator)
INVALID_VISIT(AndOperator)
INVALID_VISIT(NotOperator)
INVALID_VISIT(AdditionOperator)
INVALID_VISIT(SubtractionOperator)
INVALID_VISIT(MultiplicationOperator)
INVALID_VISIT(DivisionOperator)
INVALID_VISIT(ModOperator)
INVALID_VISIT(NotEqualOperator)
INVALID_VISIT(EqualOperator)
INVALID_VISIT(LessOperator)
INVALID_VISIT(GreaterOperator)
INVALID_VISIT(LessEqualOperator)
INVALID_VISIT(GreaterEqualOperator)
INVALID_VISIT(InListOperator)
INVALID_VISIT(SubscriptOperator)
INVALID_VISIT(ListSlicingOperator)
INVALID_VISIT(IfOperator)
INVALID_VISIT(UnaryPlusOperator)
INVALID_VISIT(UnaryMinusOperator)
INVALID_VISIT(IsNullOperator)
INVALID_VISIT(ListLiteral)
INVALID_VISIT(MapLiteral)
INVALID_VISIT(MapProjectionLiteral)
INVALID_VISIT(PropertyLookup)
INVALID_VISIT(AllPropertiesLookup)
INVALID_VISIT(LabelsTest)
INVALID_VISIT(Aggregation)
INVALID_VISIT(Function)
INVALID_VISIT(Reduce)
INVALID_VISIT(Coalesce)
INVALID_VISIT(Extract)
INVALID_VISIT(All)
INVALID_VISIT(Single)
INVALID_VISIT(Any)
INVALID_VISIT(None)
INVALID_VISIT(Identifier)
INVALID_VISIT(RegexMatch)
INVALID_VISIT(Exists)
#undef INVALID_VISIT
private:
EvaluationContext const *ctx_;
};
class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
public:
ExpressionEvaluator(Frame *frame, const SymbolTable &symbol_table, const EvaluationContext &ctx, DbAccessor *dba,
@ -1046,6 +1113,7 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
/// @throw QueryRuntimeException if expression doesn't evaluate to an int.
int64_t EvaluateInt(ExpressionEvaluator *evaluator, Expression *expr, const std::string &what);
std::optional<size_t> EvaluateMemoryLimit(ExpressionEvaluator *eval, Expression *memory_limit, size_t memory_scale);
std::optional<size_t> EvaluateMemoryLimit(ExpressionVisitor<TypedValue> &eval, Expression *memory_limit,
size_t memory_scale);
} // namespace memgraph::query

View File

@ -162,12 +162,12 @@ struct Callback {
bool should_abort_query{false};
};
TypedValue EvaluateOptionalExpression(Expression *expression, ExpressionEvaluator *eval) {
return expression ? expression->Accept(*eval) : TypedValue();
TypedValue EvaluateOptionalExpression(Expression *expression, ExpressionVisitor<TypedValue> &eval) {
return expression ? expression->Accept(eval) : TypedValue();
}
template <typename TResult>
std::optional<TResult> GetOptionalValue(query::Expression *expression, ExpressionEvaluator &evaluator) {
std::optional<TResult> GetOptionalValue(query::Expression *expression, ExpressionVisitor<TypedValue> &evaluator) {
if (expression != nullptr) {
auto int_value = expression->Accept(evaluator);
MG_ASSERT(int_value.IsNull() || int_value.IsInt());
@ -178,7 +178,8 @@ std::optional<TResult> GetOptionalValue(query::Expression *expression, Expressio
return {};
};
std::optional<std::string> GetOptionalStringValue(query::Expression *expression, ExpressionEvaluator &evaluator) {
std::optional<std::string> GetOptionalStringValue(query::Expression *expression,
ExpressionVisitor<TypedValue> &evaluator) {
if (expression != nullptr) {
auto value = expression->Accept(evaluator);
MG_ASSERT(value.IsNull() || value.IsString());
@ -356,24 +357,17 @@ class ReplQueryHandler final : public query::ReplicationQueryHandler {
/// returns false if the replication role can't be set
/// @throw QueryRuntimeException if an error ocurred.
Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_context, const Parameters &parameters,
DbAccessor *db_accessor) {
Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_context, const Parameters &parameters) {
AuthQueryHandler *auth = interpreter_context->auth;
#ifdef MG_ENTERPRISE
auto &sc_handler = memgraph::dbms::SessionContextHandler::ExtractSCH(interpreter_context);
#endif
// Empty frame for evaluation of password expression. This is OK since
// password should be either null or string literal and it's evaluation
// should not depend on frame.
Frame frame(0);
SymbolTable symbol_table;
EvaluationContext evaluation_context;
// TODO: MemoryResource for EvaluationContext, it should probably be passed as
// the argument to Callback.
EvaluationContext evaluation_context;
evaluation_context.timestamp = QueryTimestamp();
evaluation_context.parameters = parameters;
ExpressionEvaluator evaluator(&frame, symbol_table, evaluation_context, db_accessor, storage::View::OLD);
auto evaluator = PrimitiveLiteralExpressionEvaluator{evaluation_context};
std::string username = auth_query->user_;
std::string rolename = auth_query->role_;
@ -386,7 +380,7 @@ Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_
std::vector<std::unordered_map<AuthQuery::FineGrainedPrivilege, std::vector<std::string>>> edge_type_privileges =
auth_query->edge_type_privileges_;
#endif
auto password = EvaluateOptionalExpression(auth_query->password_, &evaluator);
auto password = EvaluateOptionalExpression(auth_query->password_, evaluator);
Callback callback;
@ -646,21 +640,18 @@ Callback HandleAuthQuery(AuthQuery *auth_query, InterpreterContext *interpreter_
} // namespace
Callback HandleReplicationQuery(ReplicationQuery *repl_query, const Parameters &parameters,
InterpreterContext *interpreter_context, DbAccessor *db_accessor,
std::vector<Notification> *notifications) {
Frame frame(0);
SymbolTable symbol_table;
EvaluationContext evaluation_context;
InterpreterContext *interpreter_context, std::vector<Notification> *notifications) {
// TODO: MemoryResource for EvaluationContext, it should probably be passed as
// the argument to Callback.
EvaluationContext evaluation_context;
evaluation_context.timestamp = QueryTimestamp();
evaluation_context.parameters = parameters;
ExpressionEvaluator evaluator(&frame, symbol_table, evaluation_context, db_accessor, storage::View::OLD);
auto evaluator = PrimitiveLiteralExpressionEvaluator{evaluation_context};
Callback callback;
switch (repl_query->action_) {
case ReplicationQuery::Action::SET_REPLICATION_ROLE: {
auto port = EvaluateOptionalExpression(repl_query->port_, &evaluator);
auto port = EvaluateOptionalExpression(repl_query->port_, evaluator);
std::optional<int64_t> maybe_port;
if (port.IsInt()) {
maybe_port = port.ValueInt();
@ -779,7 +770,7 @@ std::optional<std::string> StringPointerToOptional(const std::string *str) {
return str == nullptr ? std::nullopt : std::make_optional(*str);
}
stream::CommonStreamInfo GetCommonStreamInfo(StreamQuery *stream_query, ExpressionEvaluator &evaluator) {
stream::CommonStreamInfo GetCommonStreamInfo(StreamQuery *stream_query, ExpressionVisitor<TypedValue> &evaluator) {
return {
.batch_interval = GetOptionalValue<std::chrono::milliseconds>(stream_query->batch_interval_, evaluator)
.value_or(stream::kDefaultBatchInterval),
@ -787,7 +778,7 @@ stream::CommonStreamInfo GetCommonStreamInfo(StreamQuery *stream_query, Expressi
.transformation_name = stream_query->transform_name_};
}
std::vector<std::string> EvaluateTopicNames(ExpressionEvaluator &evaluator,
std::vector<std::string> EvaluateTopicNames(ExpressionVisitor<TypedValue> &evaluator,
std::variant<Expression *, std::vector<std::string>> topic_variant) {
return std::visit(utils::Overloaded{[&](Expression *expression) {
auto topic_names = expression->Accept(evaluator);
@ -798,7 +789,7 @@ std::vector<std::string> EvaluateTopicNames(ExpressionEvaluator &evaluator,
std::move(topic_variant));
}
Callback::CallbackFunction GetKafkaCreateCallback(StreamQuery *stream_query, ExpressionEvaluator &evaluator,
Callback::CallbackFunction GetKafkaCreateCallback(StreamQuery *stream_query, ExpressionVisitor<TypedValue> &evaluator,
InterpreterContext *interpreter_context,
const std::string *username) {
static constexpr std::string_view kDefaultConsumerGroup = "mg_consumer";
@ -849,7 +840,7 @@ Callback::CallbackFunction GetKafkaCreateCallback(StreamQuery *stream_query, Exp
};
}
Callback::CallbackFunction GetPulsarCreateCallback(StreamQuery *stream_query, ExpressionEvaluator &evaluator,
Callback::CallbackFunction GetPulsarCreateCallback(StreamQuery *stream_query, ExpressionVisitor<TypedValue> &evaluator,
InterpreterContext *interpreter_context,
const std::string *username) {
auto service_url = GetOptionalStringValue(stream_query->service_url_, evaluator);
@ -875,16 +866,14 @@ Callback::CallbackFunction GetPulsarCreateCallback(StreamQuery *stream_query, Ex
}
Callback HandleStreamQuery(StreamQuery *stream_query, const Parameters &parameters,
InterpreterContext *interpreter_context, DbAccessor *db_accessor,
const std::string *username, std::vector<Notification> *notifications) {
Frame frame(0);
SymbolTable symbol_table;
EvaluationContext evaluation_context;
InterpreterContext *interpreter_context, const std::string *username,
std::vector<Notification> *notifications) {
// TODO: MemoryResource for EvaluationContext, it should probably be passed as
// the argument to Callback.
EvaluationContext evaluation_context;
evaluation_context.timestamp = QueryTimestamp();
evaluation_context.parameters = parameters;
ExpressionEvaluator evaluator(&frame, symbol_table, evaluation_context, db_accessor, storage::View::OLD);
PrimitiveLiteralExpressionEvaluator evaluator{evaluation_context};
Callback callback;
switch (stream_query->action_) {
@ -1040,27 +1029,23 @@ Callback HandleConfigQuery() {
return callback;
}
Callback HandleSettingQuery(SettingQuery *setting_query, const Parameters &parameters, DbAccessor *db_accessor) {
Frame frame(0);
SymbolTable symbol_table;
EvaluationContext evaluation_context;
Callback HandleSettingQuery(SettingQuery *setting_query, const Parameters &parameters) {
// TODO: MemoryResource for EvaluationContext, it should probably be passed as
// the argument to Callback.
evaluation_context.timestamp =
std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch())
.count();
EvaluationContext evaluation_context;
evaluation_context.timestamp = QueryTimestamp();
evaluation_context.parameters = parameters;
ExpressionEvaluator evaluator(&frame, symbol_table, evaluation_context, db_accessor, storage::View::OLD);
auto evaluator = PrimitiveLiteralExpressionEvaluator{evaluation_context};
Callback callback;
switch (setting_query->action_) {
case SettingQuery::Action::SET_SETTING: {
const auto setting_name = EvaluateOptionalExpression(setting_query->setting_name_, &evaluator);
const auto setting_name = EvaluateOptionalExpression(setting_query->setting_name_, evaluator);
if (!setting_name.IsString()) {
throw utils::BasicException("Setting name should be a string literal");
}
const auto setting_value = EvaluateOptionalExpression(setting_query->setting_value_, &evaluator);
const auto setting_value = EvaluateOptionalExpression(setting_query->setting_value_, evaluator);
if (!setting_value.IsString()) {
throw utils::BasicException("Setting value should be a string literal");
}
@ -1075,7 +1060,7 @@ Callback HandleSettingQuery(SettingQuery *setting_query, const Parameters &param
return callback;
}
case SettingQuery::Action::SHOW_SETTING: {
const auto setting_name = EvaluateOptionalExpression(setting_query->setting_name_, &evaluator);
const auto setting_name = EvaluateOptionalExpression(setting_query->setting_name_, evaluator);
if (!setting_name.IsString()) {
throw utils::BasicException("Setting name should be a string literal");
}
@ -1540,14 +1525,12 @@ PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map<std::string,
FrameChangeCollector *frame_change_collector = nullptr) {
auto *cypher_query = utils::Downcast<CypherQuery>(parsed_query.query);
Frame frame(0);
SymbolTable symbol_table;
EvaluationContext evaluation_context;
evaluation_context.timestamp = QueryTimestamp();
evaluation_context.parameters = parsed_query.parameters;
auto evaluator = PrimitiveLiteralExpressionEvaluator{evaluation_context};
ExpressionEvaluator evaluator(&frame, symbol_table, evaluation_context, dba, storage::View::OLD);
const auto memory_limit = EvaluateMemoryLimit(&evaluator, cypher_query->memory_limit_, cypher_query->memory_scale_);
const auto memory_limit = EvaluateMemoryLimit(evaluator, cypher_query->memory_limit_, cypher_query->memory_scale_);
if (memory_limit) {
spdlog::info("Running query with memory limit of {}", utils::GetReadableSize(*memory_limit));
}
@ -1708,13 +1691,11 @@ PreparedQuery PrepareProfileQuery(ParsedQuery parsed_query, bool in_explicit_tra
!contains_csv && !IsCallBatchedProcedureQuery(clauses) && !IsAllShortestPathsQuery(clauses);
MG_ASSERT(cypher_query, "Cypher grammar should not allow other queries in PROFILE");
Frame frame(0);
SymbolTable symbol_table;
EvaluationContext evaluation_context;
evaluation_context.timestamp = QueryTimestamp();
evaluation_context.parameters = parsed_inner_query.parameters;
ExpressionEvaluator evaluator(&frame, symbol_table, evaluation_context, dba, storage::View::OLD);
const auto memory_limit = EvaluateMemoryLimit(&evaluator, cypher_query->memory_limit_, cypher_query->memory_scale_);
auto evaluator = PrimitiveLiteralExpressionEvaluator{evaluation_context};
const auto memory_limit = EvaluateMemoryLimit(evaluator, cypher_query->memory_limit_, cypher_query->memory_scale_);
auto cypher_query_plan = CypherQueryToPlan(
parsed_inner_query.stripped_query.hash(), std::move(parsed_inner_query.ast_storage), cypher_query,
@ -2138,38 +2119,31 @@ PreparedQuery PrepareIndexQuery(ParsedQuery parsed_query, bool in_explicit_trans
}
PreparedQuery PrepareAuthQuery(ParsedQuery parsed_query, bool in_explicit_transaction,
std::map<std::string, TypedValue> *summary, InterpreterContext *interpreter_context,
DbAccessor *dba, utils::MemoryResource *execution_memory, const std::string *username,
std::atomic<TransactionStatus> *transaction_status,
std::shared_ptr<utils::AsyncTimer> tx_timer) {
InterpreterContext *interpreter_context) {
if (in_explicit_transaction) {
throw UserModificationInMulticommandTxException();
}
auto *auth_query = utils::Downcast<AuthQuery>(parsed_query.query);
auto callback = HandleAuthQuery(auth_query, interpreter_context, parsed_query.parameters, dba);
auto callback = HandleAuthQuery(auth_query, interpreter_context, parsed_query.parameters);
SymbolTable symbol_table;
std::vector<Symbol> output_symbols;
for (const auto &column : callback.header) {
output_symbols.emplace_back(symbol_table.CreateSymbol(column, "false"));
}
auto plan = std::make_shared<CachedPlan>(std::make_unique<SingleNodeLogicalPlan>(
std::make_unique<plan::OutputTable>(output_symbols,
[fn = callback.fn](Frame *, ExecutionContext *) { return fn(); }),
0.0, AstStorage{}, symbol_table));
auto pull_plan =
std::make_shared<PullPlan>(plan, parsed_query.parameters, false, dba, interpreter_context, execution_memory,
StringPointerToOptional(username), transaction_status, std::move(tx_timer));
return PreparedQuery{
callback.header, std::move(parsed_query.required_privileges),
[pull_plan = std::move(pull_plan), callback = std::move(callback), output_symbols = std::move(output_symbols),
summary](AnyStream *stream, std::optional<int> n) -> std::optional<QueryHandlerResult> {
if (pull_plan->Pull(stream, n, output_symbols, summary)) {
return callback.should_abort_query ? QueryHandlerResult::ABORT : QueryHandlerResult::COMMIT;
std::move(callback.header), std::move(parsed_query.required_privileges),
[handler = std::move(callback.fn), pull_plan = std::shared_ptr<PullPlanVector>(nullptr), interpreter_context](
AnyStream *stream, std::optional<int> n) mutable -> std::optional<QueryHandlerResult> {
if (!pull_plan) {
// Run the specific query
auto results = handler();
pull_plan = std::make_shared<PullPlanVector>(std::move(results));
#ifdef MG_ENTERPRISE
// Invalidate auth cache after every type of AuthQuery
interpreter_context->auth_checker->ClearCache();
#endif
}
if (pull_plan->Pull(stream, n)) {
return QueryHandlerResult::COMMIT;
}
return std::nullopt;
},
@ -2177,8 +2151,8 @@ PreparedQuery PrepareAuthQuery(ParsedQuery parsed_query, bool in_explicit_transa
}
PreparedQuery PrepareReplicationQuery(ParsedQuery parsed_query, bool in_explicit_transaction,
std::vector<Notification> *notifications, InterpreterContext *interpreter_context,
DbAccessor *dba) {
std::vector<Notification> *notifications,
InterpreterContext *interpreter_context) {
if (in_explicit_transaction) {
throw ReplicationModificationInMulticommandTxException();
}
@ -2189,7 +2163,7 @@ PreparedQuery PrepareReplicationQuery(ParsedQuery parsed_query, bool in_explicit
auto *replication_query = utils::Downcast<ReplicationQuery>(parsed_query.query);
auto callback =
HandleReplicationQuery(replication_query, parsed_query.parameters, interpreter_context, dba, notifications);
HandleReplicationQuery(replication_query, parsed_query.parameters, interpreter_context, notifications);
return PreparedQuery{callback.header, std::move(parsed_query.required_privileges),
[callback_fn = std::move(callback.fn), pull_plan = std::shared_ptr<PullPlanVector>{nullptr}](
@ -2443,8 +2417,6 @@ PreparedQuery PrepareTriggerQuery(ParsedQuery parsed_query, bool in_explicit_tra
PreparedQuery PrepareStreamQuery(ParsedQuery parsed_query, bool in_explicit_transaction,
std::vector<Notification> *notifications, InterpreterContext *interpreter_context,
DbAccessor *dba,
const std::map<std::string, storage::PropertyValue> & /*user_parameters*/,
const std::string *username) {
if (in_explicit_transaction) {
throw StreamQueryInMulticommandTxException();
@ -2453,7 +2425,7 @@ PreparedQuery PrepareStreamQuery(ParsedQuery parsed_query, bool in_explicit_tran
auto *stream_query = utils::Downcast<StreamQuery>(parsed_query.query);
MG_ASSERT(stream_query);
auto callback =
HandleStreamQuery(stream_query, parsed_query.parameters, interpreter_context, dba, username, notifications);
HandleStreamQuery(stream_query, parsed_query.parameters, interpreter_context, username, notifications);
return PreparedQuery{std::move(callback.header), std::move(parsed_query.required_privileges),
[callback_fn = std::move(callback.fn), pull_plan = std::shared_ptr<PullPlanVector>{nullptr}](
@ -2678,7 +2650,7 @@ PreparedQuery PrepareSettingQuery(ParsedQuery parsed_query, bool in_explicit_tra
auto *setting_query = utils::Downcast<SettingQuery>(parsed_query.query);
MG_ASSERT(setting_query);
auto callback = HandleSettingQuery(setting_query, parsed_query.parameters, dba);
auto callback = HandleSettingQuery(setting_query, parsed_query.parameters);
return PreparedQuery{std::move(callback.header), std::move(parsed_query.required_privileges),
[callback_fn = std::move(callback.fn), pull_plan = std::shared_ptr<PullPlanVector>{nullptr}](
@ -2780,13 +2752,11 @@ std::vector<std::vector<TypedValue>> TransactionQueueQueryHandler::KillTransacti
Callback HandleTransactionQueueQuery(TransactionQueueQuery *transaction_query,
const std::optional<std::string> &username, const Parameters &parameters,
InterpreterContext *interpreter_context, DbAccessor *db_accessor) {
Frame frame(0);
SymbolTable symbol_table;
InterpreterContext *interpreter_context) {
EvaluationContext evaluation_context;
evaluation_context.timestamp = QueryTimestamp();
evaluation_context.parameters = parameters;
ExpressionEvaluator evaluator(&frame, symbol_table, evaluation_context, db_accessor, storage::View::OLD);
auto evaluator = PrimitiveLiteralExpressionEvaluator{evaluation_context};
bool hasTransactionManagementPrivilege = interpreter_context->auth_checker->IsUserAuthorized(
username, {query::AuthQuery::Privilege::TRANSACTION_MANAGEMENT}, "");
@ -2836,7 +2806,7 @@ PreparedQuery PrepareTransactionQueueQuery(ParsedQuery parsed_query, const std::
auto *transaction_queue_query = utils::Downcast<TransactionQueueQuery>(parsed_query.query);
MG_ASSERT(transaction_queue_query);
auto callback =
HandleTransactionQueueQuery(transaction_queue_query, username, parsed_query.parameters, interpreter_context, dba);
HandleTransactionQueueQuery(transaction_queue_query, username, parsed_query.parameters, interpreter_context);
return PreparedQuery{std::move(callback.header), std::move(parsed_query.required_privileges),
[callback_fn = std::move(callback.fn), pull_plan = std::shared_ptr<PullPlanVector>{nullptr}](
@ -3384,11 +3354,7 @@ PreparedQuery PrepareMultiDatabaseQuery(ParsedQuery parsed_query, bool in_explic
}
PreparedQuery PrepareShowDatabasesQuery(ParsedQuery parsed_query, InterpreterContext *interpreter_context,
const std::string &session_uuid, std::map<std::string, TypedValue> *summary,
DbAccessor *dba, utils::MemoryResource *execution_memory,
const std::optional<std::string> &username,
std::atomic<TransactionStatus> *transaction_status,
std::shared_ptr<utils::AsyncTimer> tx_timer) {
const std::string &session_uuid, const std::optional<std::string> &username) {
#ifdef MG_ENTERPRISE
if (!license::global_license_checker.IsEnterpriseValidFast()) {
throw QueryException("Trying to use enterprise feature without a valid license.");
@ -3453,26 +3419,17 @@ PreparedQuery PrepareShowDatabasesQuery(ParsedQuery parsed_query, InterpreterCon
return status;
};
SymbolTable symbol_table;
std::vector<Symbol> output_symbols;
for (const auto &column : callback.header) {
output_symbols.emplace_back(symbol_table.CreateSymbol(column, "false"));
}
auto plan = std::make_shared<CachedPlan>(std::make_unique<SingleNodeLogicalPlan>(
std::make_unique<plan::OutputTable>(output_symbols,
[fn = callback.fn](Frame *, ExecutionContext *) { return fn(); }),
0.0, AstStorage{}, symbol_table));
auto pull_plan = std::make_shared<PullPlan>(plan, parsed_query.parameters, false, dba, interpreter_context,
execution_memory, username, transaction_status, std::move(tx_timer));
return PreparedQuery{
callback.header, std::move(parsed_query.required_privileges),
[pull_plan = std::move(pull_plan), callback = std::move(callback), output_symbols = std::move(output_symbols),
summary](AnyStream *stream, std::optional<int> n) -> std::optional<QueryHandlerResult> {
if (pull_plan->Pull(stream, n, output_symbols, summary)) {
return callback.should_abort_query ? QueryHandlerResult::ABORT : QueryHandlerResult::COMMIT;
std::move(callback.header), std::move(parsed_query.required_privileges),
[handler = std::move(callback.fn), pull_plan = std::shared_ptr<PullPlanVector>(nullptr)](
AnyStream *stream, std::optional<int> n) mutable -> std::optional<QueryHandlerResult> {
if (!pull_plan) {
auto results = handler();
pull_plan = std::make_shared<PullPlanVector>(std::move(results));
}
if (pull_plan->Pull(stream, n)) {
return QueryHandlerResult::NOTHING;
}
return std::nullopt;
},
@ -3649,10 +3606,7 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
prepared_query = PrepareAnalyzeGraphQuery(std::move(parsed_query), in_explicit_transaction_,
&*execution_db_accessor_, interpreter_context_);
} else if (utils::Downcast<AuthQuery>(parsed_query.query)) {
prepared_query = PrepareAuthQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary,
interpreter_context_, &*execution_db_accessor_,
&query_execution->execution_memory_with_exception, username,
&transaction_status_, std::move(current_timer));
prepared_query = PrepareAuthQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_);
} else if (utils::Downcast<InfoQuery>(parsed_query.query)) {
prepared_query = PrepareInfoQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->summary,
interpreter_context_, interpreter_context_->db.get(),
@ -3662,9 +3616,8 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
prepared_query = PrepareConstraintQuery(std::move(parsed_query), in_explicit_transaction_,
&query_execution->notifications, interpreter_context_);
} else if (utils::Downcast<ReplicationQuery>(parsed_query.query)) {
prepared_query =
PrepareReplicationQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->notifications,
interpreter_context_, &*execution_db_accessor_);
prepared_query = PrepareReplicationQuery(std::move(parsed_query), in_explicit_transaction_,
&query_execution->notifications, interpreter_context_);
} else if (utils::Downcast<LockPathQuery>(parsed_query.query)) {
prepared_query = PrepareLockPathQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_);
} else if (utils::Downcast<FreeMemoryQuery>(parsed_query.query)) {
@ -3676,9 +3629,8 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
PrepareTriggerQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->notifications,
interpreter_context_, &*execution_db_accessor_, params, username);
} else if (utils::Downcast<StreamQuery>(parsed_query.query)) {
prepared_query =
PrepareStreamQuery(std::move(parsed_query), in_explicit_transaction_, &query_execution->notifications,
interpreter_context_, &*execution_db_accessor_, params, username);
prepared_query = PrepareStreamQuery(std::move(parsed_query), in_explicit_transaction_,
&query_execution->notifications, interpreter_context_, username);
} else if (utils::Downcast<IsolationLevelQuery>(parsed_query.query)) {
prepared_query =
PrepareIsolationLevelQuery(std::move(parsed_query), in_explicit_transaction_, interpreter_context_, this);
@ -3698,10 +3650,8 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string,
prepared_query = PrepareMultiDatabaseQuery(std::move(parsed_query), in_explicit_transaction_, in_explicit_db_,
interpreter_context_, session_uuid);
} else if (utils::Downcast<ShowDatabasesQuery>(parsed_query.query)) {
prepared_query = PrepareShowDatabasesQuery(std::move(parsed_query), interpreter_context_, session_uuid,
&query_execution->summary, &*execution_db_accessor_,
&query_execution->execution_memory_with_exception, username_,
&transaction_status_, std::move(current_timer));
prepared_query =
PrepareShowDatabasesQuery(std::move(parsed_query), interpreter_context_, session_uuid, username_);
} else {
LOG_FATAL("Should not get here -- unknown query type!");
}

View File

@ -4626,7 +4626,7 @@ class CallProcedureCursor : public Cursor {
// TODO: This will probably need to be changed when we add support for
// generator like procedures which yield a new result on new query calls.
auto *memory = self_->memory_resource;
auto memory_limit = EvaluateMemoryLimit(&evaluator, self_->memory_limit_, self_->memory_scale_);
auto memory_limit = EvaluateMemoryLimit(evaluator, self_->memory_limit_, self_->memory_scale_);
auto graph = mgp_graph::WritableGraph(*context.db_accessor, graph_view, context);
CallCustomProcedure(self_->procedure_name_, *proc, self_->arguments_, graph, &evaluator, memory, memory_limit,
result_, call_initializer);

View File

@ -76,6 +76,8 @@ class TestAuthChecker : public memgraph::query::AuthChecker {
const std::string & /*username*/, const memgraph::query::DbAccessor * /*db_accessor*/) const override {
return {};
}
void ClearCache() const override {}
};
std::filesystem::path storage_directory{std::filesystem::temp_directory_path() / "MG_test_unit_dbms_interp"};

View File

@ -47,6 +47,7 @@ class MockAuthChecker : public memgraph::query::AuthChecker {
MOCK_CONST_METHOD2(GetFineGrainedAuthChecker,
std::unique_ptr<memgraph::query::FineGrainedAuthChecker>(
const std::string &username, const memgraph::query::DbAccessor *db_accessor));
MOCK_CONST_METHOD0(ClearCache, void());
#endif
};
} // namespace