diff --git a/src/query/context.hpp b/src/query/context.hpp index d937769da..286e5adf5 100644 --- a/src/query/context.hpp +++ b/src/query/context.hpp @@ -4,6 +4,7 @@ #include "query/frontend/semantic/symbol_table.hpp" #include "query/parameters.hpp" #include "query/plan/profile.hpp" +#include "query/trigger.hpp" #include "utils/tsc.hpp" namespace query { @@ -56,6 +57,9 @@ struct ExecutionContext { std::chrono::duration<double> profile_execution_time; plan::ProfilingStats stats; plan::ProfilingStats *stats_root{nullptr}; + + // trigger context + TriggerContext *trigger_context{nullptr}; }; inline bool MustAbort(const ExecutionContext &context) { diff --git a/src/query/cypher_query_interpreter.cpp b/src/query/cypher_query_interpreter.cpp index d721f2c09..b86219364 100644 --- a/src/query/cypher_query_interpreter.cpp +++ b/src/query/cypher_query_interpreter.cpp @@ -103,19 +103,23 @@ ParsedQuery ParseQuery(const std::string &query_string, const std::map<std::stri is_cacheable}; } +namespace { std::unique_ptr<LogicalPlan> MakeLogicalPlan(AstStorage ast_storage, CypherQuery *query, const Parameters &parameters, - DbAccessor *db_accessor) { + DbAccessor *db_accessor, + const std::vector<Identifier *> &predefined_identifiers) { auto vertex_counts = plan::MakeVertexCountCache(db_accessor); - auto symbol_table = MakeSymbolTable(query); + auto symbol_table = MakeSymbolTable(query, predefined_identifiers); auto planning_context = plan::MakePlanningContext(&ast_storage, &symbol_table, query, &vertex_counts); auto [root, cost] = plan::MakeLogicalPlan(&planning_context, parameters, FLAGS_query_cost_planner); return std::make_unique<SingleNodeLogicalPlan>(std::move(root), cost, std::move(ast_storage), std::move(symbol_table)); } +} // namespace std::shared_ptr<CachedPlan> CypherQueryToPlan(uint64_t hash, AstStorage ast_storage, CypherQuery *query, const Parameters &parameters, utils::SkipList<PlanCacheEntry> *plan_cache, - DbAccessor 
*db_accessor, const bool is_cacheable, + const std::vector<Identifier *> &predefined_identifiers) { auto plan_cache_access = plan_cache->access(); auto it = plan_cache_access.find(hash); if (it != plan_cache_access.end()) { @@ -126,7 +130,8 @@ std::shared_ptr<CachedPlan> CypherQueryToPlan(uint64_t hash, AstStorage ast_stor } } - auto plan = std::make_shared<CachedPlan>(MakeLogicalPlan(std::move(ast_storage), (query), parameters, db_accessor)); + auto plan = std::make_shared<CachedPlan>( + MakeLogicalPlan(std::move(ast_storage), query, parameters, db_accessor, predefined_identifiers)); if (is_cacheable) { plan_cache_access.insert({hash, plan}); } diff --git a/src/query/cypher_query_interpreter.hpp b/src/query/cypher_query_interpreter.hpp index eaf688521..5170cd93c 100644 --- a/src/query/cypher_query_interpreter.hpp +++ b/src/query/cypher_query_interpreter.hpp @@ -129,21 +129,17 @@ class SingleNodeLogicalPlan final : public LogicalPlan { SymbolTable symbol_table_; }; -/** - * Convert a parsed *Cypher* query's AST into a logical plan. - * - * The created logical plan will take ownership of the `AstStorage` within - * `ParsedQuery` and might modify it during planning. - */ -std::unique_ptr<LogicalPlan> MakeLogicalPlan(AstStorage ast_storage, CypherQuery *query, const Parameters &parameters, - DbAccessor *db_accessor); - /** * Return the parsed *Cypher* query's AST cached logical plan, or create and * cache a fresh one if it doesn't yet exist. + * @param predefined_identifiers optional identifiers you want to inject into a query. + * If an identifier is not defined in a scope, we check the predefined identifiers. + * If an identifier is contained there, we inject it at that place and remove it, + * because a predefined identifier can be used only in one scope. 
*/ std::shared_ptr<CachedPlan> CypherQueryToPlan(uint64_t hash, AstStorage ast_storage, CypherQuery *query, const Parameters &parameters, utils::SkipList<PlanCacheEntry> *plan_cache, - DbAccessor *db_accessor, bool is_cacheable = true); + DbAccessor *db_accessor, bool is_cacheable = true, + const std::vector<Identifier *> &predefined_identifiers = {}); } // namespace query diff --git a/src/query/db_accessor.hpp b/src/query/db_accessor.hpp index 0c7f24977..a3ec2b019 100644 --- a/src/query/db_accessor.hpp +++ b/src/query/db_accessor.hpp @@ -208,6 +208,8 @@ class DbAccessor final { return std::nullopt; } + void FinalizeTransaction() { accessor_->FinalizeTransaction(); } + VerticesIterable Vertices(storage::View view) { return VerticesIterable(accessor_->Vertices(view)); } VerticesIterable Vertices(storage::View view, storage::LabelId label) { diff --git a/src/query/frontend/ast/ast.lcp b/src/query/frontend/ast/ast.lcp index ee4fdcebd..4f9fdcf4c 100644 --- a/src/query/frontend/ast/ast.lcp +++ b/src/query/frontend/ast/ast.lcp @@ -686,9 +686,7 @@ cpp<# symbol_pos_ = symbol.position(); return this; } - cpp<#) - (:protected - #>cpp + explicit Identifier(const std::string &name) : name_(name) {} Identifier(const std::string &name, bool user_declared) : name_(name), user_declared_(user_declared) {} diff --git a/src/query/frontend/semantic/symbol_generator.cpp b/src/query/frontend/semantic/symbol_generator.cpp index ea6a0f8b9..96aff3646 100644 --- a/src/query/frontend/semantic/symbol_generator.cpp +++ b/src/query/frontend/semantic/symbol_generator.cpp @@ -12,8 +12,23 @@ namespace query { +namespace { +std::unordered_map<std::string, Identifier *> GeneratePredefinedIdentifierMap( + const std::vector<Identifier *> &predefined_identifiers) { + std::unordered_map<std::string, Identifier *> identifier_map; + for (const auto &identifier : predefined_identifiers) { + identifier_map.emplace(identifier->name_, identifier); + } + + return identifier_map; +} +} // namespace + 
+SymbolGenerator::SymbolGenerator(SymbolTable *symbol_table, const std::vector<Identifier *> &predefined_identifiers) + : symbol_table_(symbol_table), predefined_identifiers_{GeneratePredefinedIdentifierMap(predefined_identifiers)} {} + auto SymbolGenerator::CreateSymbol(const std::string &name, bool user_declared, Symbol::Type type, int token_position) { - auto symbol = symbol_table_.CreateSymbol(name, user_declared, type, token_position); + auto symbol = symbol_table_->CreateSymbol(name, user_declared, type, token_position); scope_.symbols[name] = symbol; return symbol; } @@ -227,7 +242,8 @@ bool SymbolGenerator::PostVisit(Match &) { // Check variables in property maps after visiting Match, so that they can // reference symbols out of bind order. for (auto &ident : scope_.identifiers_in_match) { - if (!HasSymbol(ident->name_)) throw UnboundVariableError(ident->name_); + if (!HasSymbol(ident->name_) && !ConsumePredefinedIdentifier(ident->name_)) + throw UnboundVariableError(ident->name_); ident->MapTo(scope_.symbols[ident->name_]); } scope_.identifiers_in_match.clear(); @@ -277,7 +293,7 @@ SymbolGenerator::ReturnType SymbolGenerator::Visit(Identifier &ident) { scope_.identifiers_in_match.emplace_back(&ident); } else { // Everything else references a bound symbol. - if (!HasSymbol(ident.name_)) throw UnboundVariableError(ident.name_); + if (!HasSymbol(ident.name_) && !ConsumePredefinedIdentifier(ident.name_)) throw UnboundVariableError(ident.name_); symbol = scope_.symbols[ident.name_]; } ident.MapTo(symbol); @@ -448,10 +464,10 @@ bool SymbolGenerator::PreVisit(EdgeAtom &edge_atom) { // Create inner symbols, but don't bind them in scope, since they are to // be used in the missing filter expression. 
auto *inner_edge = edge_atom.filter_lambda_.inner_edge; - inner_edge->MapTo(symbol_table_.CreateSymbol(inner_edge->name_, inner_edge->user_declared_, Symbol::Type::EDGE)); + inner_edge->MapTo(symbol_table_->CreateSymbol(inner_edge->name_, inner_edge->user_declared_, Symbol::Type::EDGE)); auto *inner_node = edge_atom.filter_lambda_.inner_node; inner_node->MapTo( - symbol_table_.CreateSymbol(inner_node->name_, inner_node->user_declared_, Symbol::Type::VERTEX)); + symbol_table_->CreateSymbol(inner_node->name_, inner_node->user_declared_, Symbol::Type::VERTEX)); } if (edge_atom.weight_lambda_.expression) { VisitWithIdentifiers(edge_atom.weight_lambda_.expression, @@ -506,4 +522,20 @@ void SymbolGenerator::VisitWithIdentifiers(Expression *expr, const std::vector<I bool SymbolGenerator::HasSymbol(const std::string &name) { return scope_.symbols.find(name) != scope_.symbols.end(); } +bool SymbolGenerator::ConsumePredefinedIdentifier(const std::string &name) { + auto it = predefined_identifiers_.find(name); + + if (it == predefined_identifiers_.end()) { + return false; + } + + // we can only use the predefined identifier in a single scope so we remove it after creating + // a symbol for it + auto &identifier = it->second; + MG_ASSERT(!identifier->user_declared_, "Predefined symbols cannot be user declared!"); + identifier->MapTo(CreateSymbol(identifier->name_, identifier->user_declared_)); + predefined_identifiers_.erase(it); + return true; +} + } // namespace query diff --git a/src/query/frontend/semantic/symbol_generator.hpp b/src/query/frontend/semantic/symbol_generator.hpp index 0103c7f90..0da1082ec 100644 --- a/src/query/frontend/semantic/symbol_generator.hpp +++ b/src/query/frontend/semantic/symbol_generator.hpp @@ -17,7 +17,7 @@ namespace query { /// variable types. 
class SymbolGenerator : public HierarchicalTreeVisitor { public: - explicit SymbolGenerator(SymbolTable &symbol_table) : symbol_table_(symbol_table) {} + explicit SymbolGenerator(SymbolTable *symbol_table, const std::vector<Identifier *> &predefined_identifiers); using HierarchicalTreeVisitor::PostVisit; using HierarchicalTreeVisitor::PreVisit; @@ -116,6 +116,9 @@ class SymbolGenerator : public HierarchicalTreeVisitor { bool HasSymbol(const std::string &name); + // @return true if it added a predefined identifier with that name + bool ConsumePredefinedIdentifier(const std::string &name); + // Returns a freshly generated symbol. Previous mapping of the same name to a // different symbol is replaced with the new one. auto CreateSymbol(const std::string &name, bool user_declared, Symbol::Type type = Symbol::Type::ANY, @@ -129,15 +132,19 @@ class SymbolGenerator : public HierarchicalTreeVisitor { void VisitWithIdentifiers(Expression *, const std::vector<Identifier *> &); - SymbolTable &symbol_table_; + SymbolTable *symbol_table_; + + // Identifiers which are injected from outside the query. Each identifier + // is mapped by its name. 
+ std::unordered_map<std::string, Identifier *> predefined_identifiers_; Scope scope_; std::unordered_set<std::string> prev_return_names_; std::unordered_set<std::string> curr_return_names_; }; -inline SymbolTable MakeSymbolTable(CypherQuery *query) { +inline SymbolTable MakeSymbolTable(CypherQuery *query, const std::vector<Identifier *> &predefined_identifiers = {}) { SymbolTable symbol_table; - SymbolGenerator symbol_generator(symbol_table); + SymbolGenerator symbol_generator(&symbol_table, predefined_identifiers); query->single_query_->Accept(symbol_generator); for (auto *cypher_union : query->cypher_unions_) { cypher_union->Accept(symbol_generator); diff --git a/src/query/interpreter.cpp b/src/query/interpreter.cpp index 5aeaf0cf4..00c9c8520 100644 --- a/src/query/interpreter.cpp +++ b/src/query/interpreter.cpp @@ -467,7 +467,7 @@ struct PullPlanVector { struct PullPlan { explicit PullPlan(std::shared_ptr<CachedPlan> plan, const Parameters &parameters, bool is_profile_query, DbAccessor *dba, InterpreterContext *interpreter_context, utils::MemoryResource *execution_memory, - std::optional<size_t> memory_limit = {}); + TriggerContext *trigger_context = nullptr, std::optional<size_t> memory_limit = {}); std::optional<ExecutionContext> Pull(AnyStream *stream, std::optional<int> n, const std::vector<Symbol> &output_symbols, std::map<std::string, TypedValue> *summary); @@ -495,7 +495,7 @@ struct PullPlan { PullPlan::PullPlan(const std::shared_ptr<CachedPlan> plan, const Parameters &parameters, const bool is_profile_query, DbAccessor *dba, InterpreterContext *interpreter_context, utils::MemoryResource *execution_memory, - const std::optional<size_t> memory_limit) + TriggerContext *trigger_context, const std::optional<size_t> memory_limit) : plan_(plan), cursor_(plan->plan().MakeCursor(execution_memory)), frame_(plan->symbol_table().max_position(), execution_memory), @@ -512,6 +512,7 @@ PullPlan::PullPlan(const std::shared_ptr<CachedPlan> plan, const Parameters &par 
ctx_.max_execution_time_sec = interpreter_context->execution_timeout_sec; ctx_.is_shutting_down = &interpreter_context->is_shutting_down; ctx_.is_profile_query = is_profile_query; + ctx_.trigger_context = trigger_context; } std::optional<ExecutionContext> PullPlan::Pull(AnyStream *stream, std::optional<int> n, @@ -589,7 +590,7 @@ std::optional<ExecutionContext> PullPlan::Pull(AnyStream *stream, std::optional< summary->insert_or_assign("plan_execution_time", execution_time_.count()); cursor_->Shutdown(); ctx_.profile_execution_time = execution_time_; - return ctx_; + return std::move(ctx_); } using RWType = plan::ReadWriteTypeChecker::RWType; @@ -610,8 +611,8 @@ PreparedQuery Interpreter::PrepareTransactionQuery(std::string_view query_upper) in_explicit_transaction_ = true; expect_rollback_ = false; - db_accessor_.emplace(interpreter_context_->db->Access()); - execution_db_accessor_.emplace(&*db_accessor_); + db_accessor_ = std::make_unique<storage::Storage::Accessor>(interpreter_context_->db->Access()); + execution_db_accessor_.emplace(db_accessor_.get()); }; } else if (query_upper == "COMMIT") { handler = [this] { @@ -658,7 +659,7 @@ PreparedQuery Interpreter::PrepareTransactionQuery(std::string_view query_upper) PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map<std::string, TypedValue> *summary, InterpreterContext *interpreter_context, DbAccessor *dba, - utils::MemoryResource *execution_memory) { + utils::MemoryResource *execution_memory, TriggerContext *trigger_context = nullptr) { auto *cypher_query = utils::Downcast<CypherQuery>(parsed_query.query); Frame frame(0); @@ -695,7 +696,7 @@ PreparedQuery PrepareCypherQuery(ParsedQuery parsed_query, std::map<std::string, } auto pull_plan = std::make_shared<PullPlan>(plan, parsed_query.parameters, false, dba, interpreter_context, - execution_memory, memory_limit); + execution_memory, trigger_context, memory_limit); return PreparedQuery{std::move(header), std::move(parsed_query.required_privileges), 
[pull_plan = std::move(pull_plan), output_symbols = std::move(output_symbols), summary]( AnyStream *stream, std::optional<int> n) -> std::optional<QueryHandlerResult> { @@ -820,7 +821,7 @@ PreparedQuery PrepareProfileQuery(ParsedQuery parsed_query, bool in_explicit_tra AnyStream *stream, std::optional<int> n) mutable -> std::optional<QueryHandlerResult> { // No output symbols are given so that nothing is streamed. if (!ctx) { - ctx = PullPlan(plan, parameters, true, dba, interpreter_context, execution_memory, memory_limit) + ctx = PullPlan(plan, parameters, true, dba, interpreter_context, execution_memory, nullptr, memory_limit) .Pull(stream, {}, {}, summary); pull_plan = std::make_shared<PullPlanVector>(ProfilingStatsToTable(ctx->stats, ctx->profile_execution_time)); } @@ -1322,16 +1323,22 @@ Interpreter::PrepareResult Interpreter::Prepare(const std::string &query_string, if (!in_explicit_transaction_ && (utils::Downcast<CypherQuery>(parsed_query.query) || utils::Downcast<ExplainQuery>(parsed_query.query) || utils::Downcast<ProfileQuery>(parsed_query.query) || utils::Downcast<DumpQuery>(parsed_query.query))) { - db_accessor_.emplace(interpreter_context_->db->Access()); - execution_db_accessor_.emplace(&*db_accessor_); + db_accessor_ = std::make_unique<storage::Storage::Accessor>(interpreter_context_->db->Access()); + execution_db_accessor_.emplace(db_accessor_.get()); } utils::Timer planning_timer; PreparedQuery prepared_query; if (utils::Downcast<CypherQuery>(parsed_query.query)) { + if (interpreter_context_->before_commit_triggers.size() > 0 || + interpreter_context_->after_commit_triggers.size() > 0) { + trigger_context_.emplace(); + } + prepared_query = PrepareCypherQuery(std::move(parsed_query), &query_execution->summary, interpreter_context_, - &*execution_db_accessor_, &query_execution->execution_memory); + &*execution_db_accessor_, &query_execution->execution_memory, + trigger_context_ ? 
&*trigger_context_ : nullptr); } else if (utils::Downcast<ExplainQuery>(parsed_query.query)) { prepared_query = PrepareExplainQuery(std::move(parsed_query), &query_execution->summary, interpreter_context_, &*execution_db_accessor_, &query_execution->execution_memory); @@ -1398,11 +1405,13 @@ void Interpreter::Abort() { if (!db_accessor_) return; db_accessor_->Abort(); execution_db_accessor_ = std::nullopt; - db_accessor_ = std::nullopt; + db_accessor_.reset(); + trigger_context_.reset(); } namespace { -void RunTriggersIndividually(const utils::SkipList<Trigger> &triggers, InterpreterContext *interpreter_context) { +void RunTriggersIndividually(const utils::SkipList<Trigger> &triggers, InterpreterContext *interpreter_context, + TriggerContext trigger_context) { // Run the triggers for (const auto &trigger : triggers.access()) { spdlog::debug("Executing trigger '{}'", trigger.name()); @@ -1412,12 +1421,13 @@ void RunTriggersIndividually(const utils::SkipList<Trigger> &triggers, Interpret auto storage_acc = interpreter_context->db->Access(); DbAccessor db_accessor{&storage_acc}; + trigger_context.AdaptForAccessor(&db_accessor); try { trigger.Execute(&interpreter_context->plan_cache, &db_accessor, &execution_memory, *interpreter_context->tsc_frequency, interpreter_context->execution_timeout_sec, - &interpreter_context->is_shutting_down); + &interpreter_context->is_shutting_down, trigger_context); } catch (const utils::BasicException &exception) { - spdlog::warn("Trigger {} failed with exception:\n{}", trigger.name(), exception.what()); + spdlog::warn("Trigger '{}' failed with exception:\n{}", trigger.name(), exception.what()); db_accessor.Abort(); continue; } @@ -1457,15 +1467,17 @@ void Interpreter::Commit() { // a query. 
if (!db_accessor_) return; - // Run the triggers - for (const auto &trigger : interpreter_context_->before_commit_triggers.access()) { - spdlog::debug("Executing trigger '{}'", trigger.name()); - utils::MonotonicBufferResource execution_memory{kExecutionMemoryBlockSize}; - trigger.Execute(&interpreter_context_->plan_cache, &*execution_db_accessor_, &execution_memory, - *interpreter_context_->tsc_frequency, interpreter_context_->execution_timeout_sec, - &interpreter_context_->is_shutting_down); + if (trigger_context_) { + // Run the triggers + for (const auto &trigger : interpreter_context_->before_commit_triggers.access()) { + spdlog::debug("Executing trigger '{}'", trigger.name()); + utils::MonotonicBufferResource execution_memory{kExecutionMemoryBlockSize}; + trigger.Execute(&interpreter_context_->plan_cache, &*execution_db_accessor_, &execution_memory, + *interpreter_context_->tsc_frequency, interpreter_context_->execution_timeout_sec, + &interpreter_context_->is_shutting_down, *trigger_context_); + } + SPDLOG_DEBUG("Finished executing before commit triggers"); } - SPDLOG_DEBUG("Finished executing before commit triggers"); auto maybe_constraint_violation = db_accessor_->Commit(); if (maybe_constraint_violation.HasError()) { @@ -1475,8 +1487,9 @@ void Interpreter::Commit() { auto label_name = execution_db_accessor_->LabelToName(constraint_violation.label); MG_ASSERT(constraint_violation.properties.size() == 1U); auto property_name = execution_db_accessor_->PropertyToName(*constraint_violation.properties.begin()); - execution_db_accessor_ = std::nullopt; - db_accessor_ = std::nullopt; + execution_db_accessor_.reset(); + db_accessor_.reset(); + trigger_context_.reset(); throw QueryException("Unable to commit due to existence constraint violation on :{}({})", label_name, property_name); break; @@ -1487,8 +1500,9 @@ void Interpreter::Commit() { utils::PrintIterable( property_names_stream, constraint_violation.properties, ", ", [this](auto &stream, const auto &prop) { 
stream << execution_db_accessor_->PropertyToName(prop); }); - execution_db_accessor_ = std::nullopt; - db_accessor_ = std::nullopt; + execution_db_accessor_.reset(); + db_accessor_.reset(); + trigger_context_.reset(); throw QueryException("Unable to commit due to unique constraint violation on :{}({})", label_name, property_names_stream.str()); break; @@ -1496,13 +1510,20 @@ void Interpreter::Commit() { } } - execution_db_accessor_ = std::nullopt; - db_accessor_ = std::nullopt; + if (trigger_context_) { + background_thread_.AddTask([trigger_context = std::move(*trigger_context_), + interpreter_context = this->interpreter_context_, + user_transaction = std::shared_ptr(std::move(db_accessor_))]() mutable { + RunTriggersIndividually(interpreter_context->after_commit_triggers, interpreter_context, + std::move(trigger_context)); + user_transaction->FinalizeTransaction(); + SPDLOG_DEBUG("Finished executing after commit triggers"); // NOLINT(bugprone-lambda-function-name) + }); + } - background_thread_.AddTask([interpreter_context = this->interpreter_context_] { - RunTriggersIndividually(interpreter_context->after_commit_triggers, interpreter_context); - SPDLOG_DEBUG("Finished executing after commit triggers"); // NOLINT(bugprone-lambda-function-name) - }); + execution_db_accessor_.reset(); + db_accessor_.reset(); + trigger_context_.reset(); SPDLOG_DEBUG("Finished comitting the transaction"); } diff --git a/src/query/interpreter.hpp b/src/query/interpreter.hpp index 8200201d7..0eff59c4d 100644 --- a/src/query/interpreter.hpp +++ b/src/query/interpreter.hpp @@ -148,14 +148,26 @@ struct PreparedQuery { */ struct InterpreterContext { explicit InterpreterContext(storage::Storage *db) : db(db) { - // { - // auto triggers_acc = before_commit_triggers.access(); - // triggers_acc.insert(Trigger{"BeforeCreator", "CREATE (:BEFORE)", &ast_cache, &antlr_lock}); - // } - // { - // auto triggers_acc = after_commit_triggers.access(); - // triggers_acc.insert(Trigger{"AfterCreator", 
"CREATE (:AFTER)", &ast_cache, &antlr_lock}); - // } + // try { + // { + // auto storage_acc = db->Access(); + // DbAccessor dba(&storage_acc); + // auto triggers_acc = before_commit_triggers.access(); + // triggers_acc.insert(Trigger{"BeforeCreator", "UNWIND createdVertices as u SET u.before = u.id + 1", + // &ast_cache, + // &plan_cache, &dba, &antlr_lock}); + // } + // { + // auto storage_acc = db->Access(); + // DbAccessor dba(&storage_acc); + // auto triggers_acc = after_commit_triggers.access(); + // triggers_acc.insert(Trigger{"AfterCreator", "UNWIND createdVertices as u SET u.after = u.id - 1", + // &ast_cache, + // &plan_cache, &dba, &antlr_lock}); + // } + // } catch (const utils::BasicException &e) { + // spdlog::critical("Failed to create a trigger because: {}", e.what()); + // } } storage::Storage *db; @@ -307,8 +319,12 @@ class Interpreter final { InterpreterContext *interpreter_context_; - std::optional<storage::Storage::Accessor> db_accessor_; + // This cannot be std::optional because we need to move this accessor later on into a lambda capture + // which is assigned to std::function. std::function requires every object to be copyable, so we + // move this unique_ptr into a shrared_ptr. 
+ std::unique_ptr<storage::Storage::Accessor> db_accessor_; std::optional<DbAccessor> execution_db_accessor_; + std::optional<TriggerContext> trigger_context_; bool in_explicit_transaction_{false}; bool expect_rollback_{false}; diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index 567cc04b4..7fe2c7327 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -206,7 +206,10 @@ bool CreateNode::CreateNodeCursor::Pull(Frame &frame, ExecutionContext &context) SCOPED_PROFILE_OP("CreateNode"); if (input_cursor_->Pull(frame, context)) { - CreateLocalVertex(self_.node_info_, &frame, context); + auto created_vertex = CreateLocalVertex(self_.node_info_, &frame, context); + if (context.trigger_context) { + context.trigger_context->RegisterCreatedVertex(created_vertex); + } return true; } diff --git a/src/query/trigger.cpp b/src/query/trigger.cpp index 03b86ff0f..547ec06cb 100644 --- a/src/query/trigger.cpp +++ b/src/query/trigger.cpp @@ -1,28 +1,80 @@ #include "query/trigger.hpp" #include "query/context.hpp" +#include "query/cypher_query_interpreter.hpp" #include "query/db_accessor.hpp" #include "query/frontend/ast/ast.hpp" #include "query/interpret/frame.hpp" #include "utils/memory.hpp" namespace query { -Trigger::Trigger(std::string name, std::string query, utils::SkipList<QueryCacheEntry> *cache, - utils::SpinLock *antlr_lock) - : name_(std::move(name)), - parsed_statements_{ParseQuery(query, {} /* this should contain the predefined parameters */, cache, antlr_lock)} { + +namespace { +std::vector<std::pair<Identifier, trigger::IdentifierTag>> GetPredefinedIdentifiers() { + return {{{"createdVertices", false}, trigger::IdentifierTag::CREATED_VERTICES}}; +} +} // namespace + +void TriggerContext::RegisterCreatedVertex(const VertexAccessor created_vertex) { + created_vertices_.push_back(created_vertex); } -void Trigger::Execute(utils::SkipList<PlanCacheEntry> *plan_cache, DbAccessor *dba, - utils::MonotonicBufferResource 
*execution_memory, const double tsc_frequency, - const double max_execution_time_sec, std::atomic<bool> *is_shutting_down) const { +TypedValue TriggerContext::GetTypedValue(const trigger::IdentifierTag tag) const { + switch (tag) { + case trigger::IdentifierTag::CREATED_VERTICES: { + std::vector<TypedValue> typed_created_vertices; + typed_created_vertices.reserve(created_vertices_.size()); + std::transform(std::begin(created_vertices_), std::end(created_vertices_), + std::back_inserter(typed_created_vertices), + [](const auto &accessor) { return TypedValue(accessor); }); + return TypedValue(typed_created_vertices); + } + } +} + +void TriggerContext::AdaptForAccessor(DbAccessor *accessor) { + // adapt created_vertices_ + auto it = created_vertices_.begin(); + for (const auto &created_vertex : created_vertices_) { + if (auto maybe_vertex = accessor->FindVertex(created_vertex.Gid(), storage::View::OLD); maybe_vertex) { + *it = *maybe_vertex; + ++it; + } + } + created_vertices_.erase(it, created_vertices_.end()); +} + +Trigger::Trigger(std::string name, const std::string &query, utils::SkipList<QueryCacheEntry> *query_cache, + utils::SkipList<PlanCacheEntry> *plan_cache, DbAccessor *db_accessor, utils::SpinLock *antlr_lock) + : name_(std::move(name)), + parsed_statements_{ParseQuery(query, {}, query_cache, antlr_lock)}, + identifiers_{GetPredefinedIdentifiers()} { + // We check immediately if the query is valid by trying to create a plan. 
+ GetPlan(plan_cache, db_accessor); +} + +std::shared_ptr<CachedPlan> Trigger::GetPlan(utils::SkipList<PlanCacheEntry> *plan_cache, + DbAccessor *db_accessor) const { AstStorage ast_storage; ast_storage.properties_ = parsed_statements_.ast_storage.properties_; ast_storage.labels_ = parsed_statements_.ast_storage.labels_; ast_storage.edge_types_ = parsed_statements_.ast_storage.edge_types_; - auto plan = CypherQueryToPlan(parsed_statements_.stripped_query.hash(), std::move(ast_storage), - utils::Downcast<CypherQuery>(parsed_statements_.query), parsed_statements_.parameters, - plan_cache, dba, parsed_statements_.is_cacheable); + std::vector<Identifier *> predefined_identifiers; + predefined_identifiers.reserve(identifiers_.size()); + std::transform(identifiers_.begin(), identifiers_.end(), std::back_inserter(predefined_identifiers), + [](auto &identifier) { return &identifier.first; }); + + return CypherQueryToPlan(parsed_statements_.stripped_query.hash(), std::move(ast_storage), + utils::Downcast<CypherQuery>(parsed_statements_.query), parsed_statements_.parameters, + plan_cache, db_accessor, parsed_statements_.is_cacheable, predefined_identifiers); +} + +void Trigger::Execute(utils::SkipList<PlanCacheEntry> *plan_cache, DbAccessor *dba, + utils::MonotonicBufferResource *execution_memory, const double tsc_frequency, + const double max_execution_time_sec, std::atomic<bool> *is_shutting_down, + const TriggerContext &context) const { + auto plan = GetPlan(plan_cache, dba); + ExecutionContext ctx; ctx.db_accessor = dba; ctx.symbol_table = plan->symbol_table(); @@ -55,6 +107,14 @@ void Trigger::Execute(utils::SkipList<PlanCacheEntry> *plan_cache, DbAccessor *d auto cursor = plan->plan().MakeCursor(execution_memory); Frame frame{plan->symbol_table().max_position(), execution_memory}; + for (const auto &[identifier, tag] : identifiers_) { + if (identifier.symbol_pos_ == -1) { + continue; + } + + frame[plan->symbol_table().at(identifier)] = context.GetTypedValue(tag); + } + 
while (cursor->Pull(frame, ctx)) ; diff --git a/src/query/trigger.hpp b/src/query/trigger.hpp index db59b3172..41fca2b14 100644 --- a/src/query/trigger.hpp +++ b/src/query/trigger.hpp @@ -1,16 +1,36 @@ #pragma once #include "query/cypher_query_interpreter.hpp" +#include "query/db_accessor.hpp" #include "query/frontend/ast/ast.hpp" namespace query { + +namespace trigger { +enum class IdentifierTag : uint8_t { CREATED_VERTICES }; +} // namespace trigger + +struct TriggerContext { + void RegisterCreatedVertex(VertexAccessor created_vertex); + + // Adapt the TriggerContext object inplace for a different DbAccessor + // (each derived accessor, e.g. VertexAccessor, gets adapted + // to the sent DbAccessor so they can be used safely) + void AdaptForAccessor(DbAccessor *accessor); + + TypedValue GetTypedValue(trigger::IdentifierTag tag) const; + + private: + std::vector<VertexAccessor> created_vertices_; +}; + struct Trigger { - explicit Trigger(std::string name, std::string query, utils::SkipList<QueryCacheEntry> *cache, - utils::SpinLock *antlr_lock); + explicit Trigger(std::string name, const std::string &query, utils::SkipList<QueryCacheEntry> *query_cache, + utils::SkipList<PlanCacheEntry> *plan_cache, DbAccessor *db_accessor, utils::SpinLock *antlr_lock); void Execute(utils::SkipList<PlanCacheEntry> *plan_cache, DbAccessor *dba, utils::MonotonicBufferResource *execution_memory, double tsc_frequency, double max_execution_time_sec, - std::atomic<bool> *is_shutting_down) const; + std::atomic<bool> *is_shutting_down, const TriggerContext &context) const; bool operator==(const Trigger &other) const { return name_ == other.name_; } // NOLINTNEXTLINE (modernize-use-nullptr) @@ -19,10 +39,14 @@ struct Trigger { // NOLINTNEXTLINE (modernize-use-nullptr) bool operator<(const std::string &other) const { return name_ < other; } - const auto &name() const { return name_; } + const auto &name() const noexcept { return name_; } private: + std::shared_ptr<CachedPlan> 
GetPlan(utils::SkipList<PlanCacheEntry> *plan_cache, DbAccessor *db_accessor) const; + std::string name_; ParsedQuery parsed_statements_; + + mutable std::vector<std::pair<Identifier, trigger::IdentifierTag>> identifiers_; }; } // namespace query diff --git a/src/storage/v2/storage.cpp b/src/storage/v2/storage.cpp index b699dc250..d526b328a 100644 --- a/src/storage/v2/storage.cpp +++ b/src/storage/v2/storage.cpp @@ -404,17 +404,22 @@ Storage::Accessor::Accessor(Storage *storage) Storage::Accessor::Accessor(Accessor &&other) noexcept : storage_(other.storage_), + storage_guard_(std::move(other.storage_guard_)), transaction_(std::move(other.transaction_)), + commit_timestamp_(other.commit_timestamp_), is_transaction_active_(other.is_transaction_active_), config_(other.config_) { // Don't allow the other accessor to abort our transaction in destructor. other.is_transaction_active_ = false; + other.commit_timestamp_.reset(); } Storage::Accessor::~Accessor() { if (is_transaction_active_) { Abort(); } + + FinalizeTransaction(); } VertexAccessor Storage::Accessor::CreateVertex() { @@ -793,11 +798,10 @@ utils::BasicResult<ConstraintViolation, void> Storage::Accessor::Commit( // Save these so we can mark them used in the commit log. uint64_t start_timestamp = transaction_.start_timestamp; - uint64_t commit_timestamp; { std::unique_lock<utils::SpinLock> engine_guard(storage_->engine_lock_); - commit_timestamp = storage_->CommitTimestamp(desired_commit_timestamp); + commit_timestamp_.emplace(storage_->CommitTimestamp(desired_commit_timestamp)); // Before committing and validating vertices against unique constraints, // we have to update unique constraints with the vertices that are going @@ -821,7 +825,7 @@ utils::BasicResult<ConstraintViolation, void> Storage::Accessor::Commit( // No need to take any locks here because we modified this vertex and no // one else can touch it until we commit. 
unique_constraint_violation = - storage_->constraints_.unique_constraints.Validate(*prev.vertex, transaction_, commit_timestamp); + storage_->constraints_.unique_constraints.Validate(*prev.vertex, transaction_, *commit_timestamp_); if (unique_constraint_violation) { break; } @@ -838,7 +842,7 @@ utils::BasicResult<ConstraintViolation, void> Storage::Accessor::Commit( // Replica can log only the write transaction received from Main // so the Wal files are consistent if (storage_->replication_role_ == ReplicationRole::MAIN || desired_commit_timestamp.has_value()) { - storage_->AppendToWal(transaction_, commit_timestamp); + storage_->AppendToWal(transaction_, *commit_timestamp_); } // Take committed_transactions lock while holding the engine lock to @@ -848,12 +852,12 @@ utils::BasicResult<ConstraintViolation, void> Storage::Accessor::Commit( // TODO: release lock, and update all deltas to have a local copy // of the commit timestamp MG_ASSERT(transaction_.commit_timestamp != nullptr, "Invalid database state!"); - transaction_.commit_timestamp->store(commit_timestamp, std::memory_order_release); + transaction_.commit_timestamp->store(*commit_timestamp_, std::memory_order_release); // Replica can only update the last commit timestamp with // the commits received from main. if (storage_->replication_role_ == ReplicationRole::MAIN || desired_commit_timestamp.has_value()) { // Update the last commit timestamp - storage_->last_commit_timestamp_.store(commit_timestamp); + storage_->last_commit_timestamp_.store(*commit_timestamp_); } // Release engine lock because we don't have to hold it anymore // and emplace back could take a long time. 
@@ -862,13 +866,11 @@ utils::BasicResult<ConstraintViolation, void> Storage::Accessor::Commit( }); storage_->commit_log_->MarkFinished(start_timestamp); - storage_->commit_log_->MarkFinished(commit_timestamp); } } if (unique_constraint_violation) { Abort(); - storage_->commit_log_->MarkFinished(commit_timestamp); return *unique_constraint_violation; } } @@ -1041,6 +1043,13 @@ void Storage::Accessor::Abort() { is_transaction_active_ = false; } +void Storage::Accessor::FinalizeTransaction() { + if (commit_timestamp_) { + storage_->commit_log_->MarkFinished(*commit_timestamp_); + commit_timestamp_.reset(); + } +} + const std::string &Storage::LabelToName(LabelId label) const { return name_id_mapper_.IdToName(label.AsUint()); } const std::string &Storage::PropertyToName(PropertyId property) const { diff --git a/src/storage/v2/storage.hpp b/src/storage/v2/storage.hpp index 6e3275caa..2d1b77f70 100644 --- a/src/storage/v2/storage.hpp +++ b/src/storage/v2/storage.hpp @@ -300,6 +300,8 @@ class Storage final { /// @throw std::bad_alloc void Abort(); + void FinalizeTransaction(); + private: /// @throw std::bad_alloc VertexAccessor CreateVertex(storage::Gid gid); @@ -310,6 +312,7 @@ class Storage final { Storage *storage_; std::shared_lock<utils::RWLock> storage_guard_; Transaction transaction_; + std::optional<uint64_t> commit_timestamp_; bool is_transaction_active_; Config::Items config_; }; diff --git a/tests/unit/query_common.hpp b/tests/unit/query_common.hpp index 8f8c6eb66..67698075f 100644 --- a/tests/unit/query_common.hpp +++ b/tests/unit/query_common.hpp @@ -463,7 +463,7 @@ auto GetMerge(AstStorage &storage, Pattern *pattern, OnMatch on_match, OnCreate #define MATCH(...) query::test_common::GetWithPatterns(storage.Create<query::Match>(), {__VA_ARGS__}) #define WHERE(expr) storage.Create<query::Where>((expr)) #define CREATE(...) 
query::test_common::GetWithPatterns(storage.Create<query::Create>(), {__VA_ARGS__}) -#define IDENT(name) storage.Create<query::Identifier>((name)) +#define IDENT(...) storage.Create<query::Identifier>(__VA_ARGS__) #define LITERAL(val) storage.Create<query::PrimitiveLiteral>((val)) #define LIST(...) storage.Create<query::ListLiteral>(std::vector<query::Expression *>{__VA_ARGS__}) #define MAP(...) \ diff --git a/tests/unit/query_semantic.cpp b/tests/unit/query_semantic.cpp index 4d243382a..1e19f6594 100644 --- a/tests/unit/query_semantic.cpp +++ b/tests/unit/query_semantic.cpp @@ -3,6 +3,7 @@ #include "gtest/gtest.h" +#include "query/exceptions.hpp" #include "query/frontend/ast/ast.hpp" #include "query/frontend/semantic/symbol_generator.hpp" #include "query/frontend/semantic/symbol_table.hpp" @@ -1093,3 +1094,53 @@ TEST(TestSymbolTable, CreateAnonymousSymbolWithExistingUserSymbolCalledAnon) { auto anon2 = symbol_table.CreateAnonymousSymbol(); ASSERT_EQ(anon2.name_, "anon2"); } + +TEST_F(TestSymbolGenerator, PredefinedIdentifiers) { + auto *first_op = IDENT("first_op", false); + auto *second_op = IDENT("second_op", false); + // RETURN first_op + second_op AS result + auto query = QUERY(SINGLE_QUERY(RETURN(ADD(first_op, second_op), AS("result")))); + EXPECT_THROW(query::MakeSymbolTable(query), SemanticException); + EXPECT_THROW(query::MakeSymbolTable(query, {first_op}), SemanticException); + EXPECT_THROW(query::MakeSymbolTable(query, {second_op}), SemanticException); + auto symbol_table = query::MakeSymbolTable(query, {first_op, second_op}); + ASSERT_EQ(symbol_table.max_position(), 3); + + // predefined identifier can only be used in one scope + // RETURN first_op + second_op AS result UNION RETURN second_op + first_op AS result + query = QUERY(SINGLE_QUERY(RETURN(ADD(first_op, second_op), AS("result"))), + UNION(SINGLE_QUERY(RETURN(ADD(second_op, first_op), AS("result"))))); + ASSERT_THROW(query::MakeSymbolTable(query, {first_op, second_op}), SemanticException); + + 
// predefined identifier can be introduced in any of the scopes + // different predefined identifiers can be introduced in different scopes + // RETURN first_op AS result UNION RETURN second_op AS result + query = QUERY(SINGLE_QUERY(RETURN(first_op, AS("result"))), UNION(SINGLE_QUERY(RETURN(second_op, AS("result"))))); + ASSERT_THROW(query::MakeSymbolTable(query), SemanticException); + symbol_table = query::MakeSymbolTable(query, {first_op, second_op}); + ASSERT_EQ(symbol_table.max_position(), 5); + + // WITH statement resets the scope, but the predefined identifier is okay + // because it's the first introduction of it in the query + // WITH 1 as one RETURN first_op AS first + query = QUERY(SINGLE_QUERY(WITH(LITERAL(1), AS("one")), RETURN(first_op, AS("first")))); + ASSERT_THROW(query::MakeSymbolTable(query), SemanticException); + symbol_table = query::MakeSymbolTable(query, {first_op}); + ASSERT_EQ(symbol_table.max_position(), 3); + + // In the first scope, first_op represents the identifier created by match, + // in the second it represents the predefined identifier + // MATCH(first_op) WITH first_op as n RETURN first_op, n + query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("first_op"))), WITH("first_op", AS("n")), RETURN("first_op", "n"))); + ASSERT_THROW(query::MakeSymbolTable(query), SemanticException); + symbol_table = query::MakeSymbolTable(query, {first_op}); + ASSERT_EQ(symbol_table.max_position(), 6); + + // You cannot redeclare the predefined identifier in the same scope + // UNWIND first_op as u CREATE(first_op {prop: u}) + auto unwind = UNWIND(first_op, AS("u")); + auto node = NODE("first_op"); + node->properties_[storage.GetPropertyIx("prop")] = dynamic_cast<Identifier *>(unwind->named_expression_->expression_); + query = QUERY(SINGLE_QUERY(unwind, CREATE(PATTERN(node)))); + ASSERT_THROW(query::MakeSymbolTable(query, {first_op}), SemanticException); +}