From f360555e1b9c7073235aa34b0ffe19470cdb80b6 Mon Sep 17 00:00:00 2001 From: DavIvek Date: Mon, 26 Feb 2024 21:01:03 +0100 Subject: [PATCH] create label via variable --- src/query/frontend/ast/ast.hpp | 8 ++- .../frontend/ast/cypher_main_visitor.cpp | 23 ++++---- .../frontend/opencypher/grammar/Cypher.g4 | 2 +- src/query/plan/operator.cpp | 54 ++++++++++++++----- src/query/plan/operator.hpp | 12 +++-- src/query/plan/preprocess.cpp | 14 +++-- src/query/plan/pretty_print.cpp | 12 ++++- src/query/plan/rule_based_planner.hpp | 18 +++++-- 8 files changed, 101 insertions(+), 42 deletions(-) diff --git a/src/query/frontend/ast/ast.hpp b/src/query/frontend/ast/ast.hpp index 5e97332d8..5dc1b1c18 100644 --- a/src/query/frontend/ast/ast.hpp +++ b/src/query/frontend/ast/ast.hpp @@ -1770,7 +1770,7 @@ class NodeAtom : public memgraph::query::PatternAtom { return visitor.PostVisit(*this); } - std::vector labels_; + std::vector> labels_; std::variant, memgraph::query::ParameterLookup *> properties_; @@ -1780,7 +1780,11 @@ class NodeAtom : public memgraph::query::PatternAtom { object->identifier_ = identifier_ ? identifier_->Clone(storage) : nullptr; object->labels_.resize(labels_.size()); for (auto i = 0; i < object->labels_.size(); ++i) { - object->labels_[i] = storage->GetLabelIx(labels_[i].name); + if (const auto *label = std::get_if(&labels_[i])) { + object->labels_[i] = storage->GetLabelIx(label->name); + } else { + object->labels_[i] = std::get(labels_[i])->Clone(storage); + } } if (const auto *properties = std::get_if>(&properties_)) { auto &new_obj_properties = std::get>(object->properties_); diff --git a/src/query/frontend/ast/cypher_main_visitor.cpp b/src/query/frontend/ast/cypher_main_visitor.cpp index d3747bc3f..a0f521619 100644 --- a/src/query/frontend/ast/cypher_main_visitor.cpp +++ b/src/query/frontend/ast/cypher_main_visitor.cpp @@ -1912,7 +1912,7 @@ antlrcpp::Any CypherMainVisitor::visitNodePattern(MemgraphCypher::NodePatternCon anonymous_identifiers.push_back(&node->identifier_); } if (ctx->nodeLabels()) { - node->labels_ = std::any_cast>(ctx->nodeLabels()->accept(this)); + node->labels_ = std::any_cast>>(ctx->nodeLabels()->accept(this)); } if (ctx->properties()) { // This can return either properties or parameters @@ -1926,16 +1926,21 @@ antlrcpp::Any CypherMainVisitor::visitNodePattern(MemgraphCypher::NodePatternCon } antlrcpp::Any CypherMainVisitor::visitNodeLabels(MemgraphCypher::NodeLabelsContext *ctx) { - std::vector labels; + std::vector> labels; for (auto *node_label : ctx->nodeLabel()) { - if (node_label->labelName()->symbolicName()) { - labels.emplace_back(AddLabel(std::any_cast(node_label->accept(this)))); + if (node_label->labelName()) { + if (node_label->labelName()->symbolicName()) { + labels.emplace_back(AddLabel(std::any_cast(node_label->accept(this)))); + } else { + // If we have a parameter, we have to resolve it. + const auto *param_lookup = std::any_cast(node_label->accept(this)); + const auto label_name = parameters_->AtTokenPosition(param_lookup->token_position_).ValueString(); + labels.emplace_back(storage_->GetLabelIx(label_name)); + query_info_.is_cacheable = false; // We can't cache queries with label parameters. + } } else { - // If we have a parameter, we have to resolve it. - const auto *param_lookup = std::any_cast(node_label->accept(this)); - const auto label_name = parameters_->AtTokenPosition(param_lookup->token_position_).ValueString(); - labels.emplace_back(storage_->GetLabelIx(label_name)); - query_info_.is_cacheable = false; // We can't cache queries with label parameters. + // expression + labels.emplace_back(std::any_cast(node_label->accept(this))); } } return labels; diff --git a/src/query/frontend/opencypher/grammar/Cypher.g4 b/src/query/frontend/opencypher/grammar/Cypher.g4 index 55cb53ef3..84f959ccf 100644 --- a/src/query/frontend/opencypher/grammar/Cypher.g4 +++ b/src/query/frontend/opencypher/grammar/Cypher.g4 @@ -191,7 +191,7 @@ relationshipTypes : ':' relTypeName ( '|' ':'? relTypeName )* ; nodeLabels : nodeLabel ( nodeLabel )* ; -nodeLabel : ':' labelName ; +nodeLabel : ':' (labelName | expression); labelName : symbolicName | parameter; diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index 75b531261..4f6a74e09 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -216,8 +216,21 @@ VertexAccessor &CreateLocalVertex(const NodeCreationInfo &node_info, Frame *fram auto &dba = *context.db_accessor; auto new_node = dba.InsertVertex(); context.execution_stats[ExecutionStats::Key::CREATED_NODES] += 1; + // Evaluator should use the latest accessors, as modified in this query, when + // setting properties on new nodes. + ExpressionEvaluator evaluator(frame, context.symbol_table, context.evaluation_context, context.db_accessor, + storage::View::NEW); for (auto label : node_info.labels) { - auto maybe_error = new_node.AddLabel(label); + auto maybe_error = std::invoke([&] { + if (const auto *label_atom = std::get_if(&label)) { + return new_node.AddLabel(*label_atom); + } else { + // auto key = evaluator.Visit(*std::get(label)); + // return new_node.AddLabel(dba.NameToLabel(key.ValueString())); + auto expression = std::get(label); + return new_node.AddLabel(dba.NameToLabel(expression->Accept(evaluator).ValueString())); + } + }); if (maybe_error.HasError()) { switch (maybe_error.GetError()) { case storage::Error::SERIALIZATION_ERROR: @@ -232,10 +245,6 @@ VertexAccessor &CreateLocalVertex(const NodeCreationInfo &node_info, Frame *fram } context.execution_stats[ExecutionStats::Key::CREATED_LABELS] += 1; } - // Evaluator should use the latest accessors, as modified in this query, when - // setting properties on new nodes. - ExpressionEvaluator evaluator(frame, context.symbol_table, context.evaluation_context, context.db_accessor, - storage::View::NEW); // TODO: PropsSetChecked allocates a PropertyValue, make it use context.memory // when we update PropertyValue with custom allocator. std::map properties; @@ -275,10 +284,22 @@ CreateNode::CreateNodeCursor::CreateNodeCursor(const CreateNode &self, utils::Me bool CreateNode::CreateNodeCursor::Pull(Frame &frame, ExecutionContext &context) { OOMExceptionEnabler oom_exception; SCOPED_PROFILE_OP("CreateNode"); + ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, + storage::View::NEW); + std::vector labels; + for (auto &label : self_.node_info_.labels) { + if (const auto *label_atom = std::get_if(&label)) { + labels.emplace_back(*label_atom); + } else { + // auto key = evaluator.Visit(*std::get(label)); + // labels.emplace_back(context.db_accessor->NameToLabel(key.ValueString())); + auto expression = std::get(label); + labels.emplace_back(context.db_accessor->NameToLabel(expression->Accept(evaluator).ValueString())); + } + } #ifdef MG_ENTERPRISE if (license::global_license_checker.IsEnterpriseValidFast() && context.auth_checker && - !context.auth_checker->Has(self_.node_info_.labels, - memgraph::query::AuthQuery::FineGrainedPrivilege::CREATE_DELETE)) { + !context.auth_checker->Has(labels, memgraph::query::AuthQuery::FineGrainedPrivilege::CREATE_DELETE)) { throw QueryRuntimeException("Vertex not created due to not having enough permission!"); } #endif @@ -368,8 +389,19 @@ bool CreateExpand::CreateExpandCursor::Pull(Frame &frame, ExecutionContext &cont SCOPED_PROFILE_OP_BY_REF(self_); if (!input_cursor_->Pull(frame, context)) return false; + ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, + storage::View::NEW); #ifdef MG_ENTERPRISE + std::vector labels; + for (auto label : self_.node_info_.labels) { + if (const auto *label_atom = std::get_if(&label)) { + labels.emplace_back(*label_atom); + } else { + auto expression = std::get(label); + labels.emplace_back(context.db_accessor->NameToLabel(expression->Accept(evaluator).ValueString())); + } + } if (license::global_license_checker.IsEnterpriseValidFast()) { const auto fine_grained_permission = self_.existing_node_ ? memgraph::query::AuthQuery::FineGrainedPrivilege::UPDATE @@ -379,7 +411,7 @@ bool CreateExpand::CreateExpandCursor::Pull(Frame &frame, ExecutionContext &cont if (context.auth_checker && !(context.auth_checker->Has(self_.edge_info_.edge_type, memgraph::query::AuthQuery::FineGrainedPrivilege::CREATE_DELETE) && - context.auth_checker->Has(self_.node_info_.labels, fine_grained_permission))) { + context.auth_checker->Has(labels, fine_grained_permission))) { throw QueryRuntimeException("Edge not created due to not having enough permission!"); } } @@ -389,12 +421,6 @@ bool CreateExpand::CreateExpandCursor::Pull(Frame &frame, ExecutionContext &cont ExpectType(self_.input_symbol_, vertex_value, TypedValue::Type::Vertex); auto &v1 = vertex_value.ValueVertex(); - // Similarly to CreateNode, newly created edges and nodes should use the - // storage::View::NEW. - // E.g. we pickup new properties: `CREATE (n {p: 42}) -[:r {ep: n.p}]-> ()` - ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, - storage::View::NEW); - // get the destination vertex (possibly an existing node) auto &v2 = OtherVertex(frame, context); diff --git a/src/query/plan/operator.hpp b/src/query/plan/operator.hpp index 516ef2e38..9f240fda3 100644 --- a/src/query/plan/operator.hpp +++ b/src/query/plan/operator.hpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -290,18 +290,20 @@ struct NodeCreationInfo { NodeCreationInfo() = default; - NodeCreationInfo(Symbol symbol, std::vector labels, + NodeCreationInfo(Symbol symbol, std::vector> labels, std::variant properties) : symbol{std::move(symbol)}, labels{std::move(labels)}, properties{std::move(properties)} {}; - NodeCreationInfo(Symbol symbol, std::vector labels, PropertiesMapList properties) + NodeCreationInfo(Symbol symbol, std::vector> labels, + PropertiesMapList properties) : symbol{std::move(symbol)}, labels{std::move(labels)}, properties{std::move(properties)} {}; - NodeCreationInfo(Symbol symbol, std::vector labels, ParameterLookup *properties) + NodeCreationInfo(Symbol symbol, std::vector> labels, + ParameterLookup *properties) : symbol{std::move(symbol)}, labels{std::move(labels)}, properties{properties} {}; Symbol symbol; - std::vector labels; + std::vector> labels; std::variant properties; NodeCreationInfo Clone(AstStorage *storage) const { diff --git a/src/query/plan/preprocess.cpp b/src/query/plan/preprocess.cpp index c3bfdf462..c018187fc 100644 --- a/src/query/plan/preprocess.cpp +++ b/src/query/plan/preprocess.cpp @@ -358,11 +358,17 @@ void Filters::CollectPatternFilters(Pattern &pattern, SymbolTable &symbol_table, }; auto add_node_filter = [&](NodeAtom *node) { const auto &node_symbol = symbol_table.at(*node->identifier_); - if (!node->labels_.empty()) { - // Create a LabelsTest and store it. - auto *labels_test = storage.Create(node->identifier_, node->labels_); + std::vector labels; + for (auto label : node->labels_) { + if (const auto *label_node = std::get_if(&label)) { + throw SemanticException("Parameter lookup not supported in MATCH/MERGE clause!"); + } + labels.push_back(std::get(label)); + } + if (!labels.empty()) { + auto *labels_test = storage.Create(node->identifier_, labels); auto label_filter = FilterInfo{FilterInfo::Type::Label, labels_test, std::unordered_set{node_symbol}}; - label_filter.labels = node->labels_; + label_filter.labels = labels; all_filters_.emplace_back(label_filter); } add_properties(node); diff --git a/src/query/plan/pretty_print.cpp b/src/query/plan/pretty_print.cpp index a2df9422c..082711c5c 100644 --- a/src/query/plan/pretty_print.cpp +++ b/src/query/plan/pretty_print.cpp @@ -1,4 +1,4 @@ -// Copyright 2023 Memgraph Ltd. +// Copyright 2024 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -337,7 +337,15 @@ json ToJson(const std::vector> &pro json ToJson(const NodeCreationInfo &node_info, const DbAccessor &dba) { json self; self["symbol"] = ToJson(node_info.symbol); - self["labels"] = ToJson(node_info.labels, dba); + std::vector labels; + for (auto label : node_info.labels) { + if (const auto *label_node = std::get_if(&label)) { + labels = {}; + break; + } + labels.push_back(std::get(label)); + } + self["labels"] = ToJson(labels, dba); const auto *props = std::get_if(&node_info.properties); self["properties"] = ToJson(props ? *props : PropertiesMapList{}, dba); return self; diff --git a/src/query/plan/rule_based_planner.hpp b/src/query/plan/rule_based_planner.hpp index 7fba3b623..3b369e51b 100644 --- a/src/query/plan/rule_based_planner.hpp +++ b/src/query/plan/rule_based_planner.hpp @@ -311,11 +311,19 @@ class RuleBasedPlanner { std::unordered_set &bound_symbols) { auto node_to_creation_info = [&](const NodeAtom &node) { const auto &node_symbol = symbol_table.at(*node.identifier_); - std::vector labels; - labels.reserve(node.labels_.size()); - for (const auto &label : node.labels_) { - labels.push_back(GetLabel(label)); - } + + auto labels = std::invoke([&]() -> std::vector> { + std::vector> labels; + labels.reserve(node.labels_.size()); + for (const auto &label : node.labels_) { + if (const auto *label_atom = std::get_if(&label)) { + labels.emplace_back(GetLabel(*label_atom)); + } else { + labels.emplace_back(std::get(label)); + } + } + return labels; + }); auto properties = std::invoke([&]() -> std::variant { if (const auto *node_properties =