diff --git a/src/query/v2/plan/pretty_print.cpp b/src/query/v2/plan/pretty_print.cpp index 76cc6ac26..4be8f4884 100644 --- a/src/query/v2/plan/pretty_print.cpp +++ b/src/query/v2/plan/pretty_print.cpp @@ -86,6 +86,14 @@ bool PlanPrinter::PreVisit(query::v2::plan::ScanAllByLabelProperty &op) { return true; } +bool PlanPrinter::PreVisit(query::v2::plan::ScanAllByPrimaryKey &op) { + WithPrintLn([&](auto &out) { + out << "* ScanAllByPrimaryKey" + << " (" << op.output_symbol_.name() << " :" << request_manager_->LabelToName(op.label_) << ")"; + }); + return true; +} + bool PlanPrinter::PreVisit(query::v2::plan::Expand &op) { WithPrintLn([&](auto &out) { *out_ << "* Expand (" << op.input_symbol_.name() << ")" @@ -480,6 +488,19 @@ bool PlanToJsonVisitor::PreVisit(ScanAllByLabelProperty &op) { return false; } +bool PlanToJsonVisitor::PreVisit(ScanAllByPrimaryKey &op) { + json self; + self["name"] = "ScanAllByPrimaryKey"; + self["label"] = ToJson(op.label_, *request_manager_); + self["output_symbol"] = ToJson(op.output_symbol_); + + op.input_->Accept(*this); + self["input"] = PopOutput(); + + output_ = std::move(self); + return false; +} + bool PlanToJsonVisitor::PreVisit(CreateNode &op) { json self; self["name"] = "CreateNode"; diff --git a/src/query/v2/plan/pretty_print.hpp b/src/query/v2/plan/pretty_print.hpp index 66fa31556..c1c84db40 100644 --- a/src/query/v2/plan/pretty_print.hpp +++ b/src/query/v2/plan/pretty_print.hpp @@ -68,6 +68,7 @@ class PlanPrinter : public virtual HierarchicalLogicalOperatorVisitor { bool PreVisit(ScanAllByLabelPropertyValue &) override; bool PreVisit(ScanAllByLabelPropertyRange &) override; bool PreVisit(ScanAllByLabelProperty &) override; + bool PreVisit(query::v2::plan::ScanAllByPrimaryKey &) override; bool PreVisit(Expand &) override; bool PreVisit(ExpandVariable &) override; @@ -195,6 +196,7 @@ class PlanToJsonVisitor : public virtual HierarchicalLogicalOperatorVisitor { bool PreVisit(ScanAllByLabelPropertyRange &) override; bool PreVisit(ScanAllByLabelPropertyValue &) override; bool PreVisit(ScanAllByLabelProperty &) override; + bool PreVisit(ScanAllByPrimaryKey &) override; bool PreVisit(Produce &) override; bool PreVisit(Accumulate &) override; diff --git a/src/query/v2/plan/rewrite/index_lookup.hpp b/src/query/v2/plan/rewrite/index_lookup.hpp index 41e24d2b6..1a4a4160c 100644 --- a/src/query/v2/plan/rewrite/index_lookup.hpp +++ b/src/query/v2/plan/rewrite/index_lookup.hpp @@ -273,6 +273,16 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor { return true; } + bool PreVisit(ScanAllByPrimaryKey &op) override { + prev_ops_.push_back(&op); + return true; + } + + bool PostVisit(ScanAllByPrimaryKey &) override { + prev_ops_.pop_back(); + return true; + } + bool PreVisit(ConstructNamedPath &op) override { prev_ops_.push_back(&op); return true; @@ -480,6 +490,12 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor { storage::v3::PropertyId GetProperty(PropertyIx prop) { return db_->NameToProperty(prop.name); } + void EraseLabelFilters(const memgraph::query::v2::Symbol &node_symbol, memgraph::query::v2::LabelIx prim_label) { + std::vector removed_expressions; + filters_.EraseLabelFilter(node_symbol, prim_label, &removed_expressions); + filter_exprs_for_removal_.insert(removed_expressions.begin(), removed_expressions.end()); + } + std::optional FindBestLabelIndex(const std::unordered_set &labels) { MG_ASSERT(!labels.empty(), "Trying to find the best label without any labels."); std::optional best_label; @@ -564,19 +580,28 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor { // First, try to see if we can find a vertex based on the possibly // supplied primary key. auto property_filters = filters_.PropertyFilters(node_symbol); - storage::v3::LabelId prim_label; - std::vector primary_key; + query::v2::LabelIx prim_label; + std::vector> primary_key; if (!property_filters.empty()) { for (const auto &label : labels) { if (db_->LabelIndexExists(GetLabel(label))) { - prim_label = GetLabel(label); - primary_key = db_->ExtractPrimaryKey(prim_label, property_filters); + prim_label = label; + primary_key = db_->ExtractPrimaryKey(GetLabel(prim_label), property_filters); break; } } if (!primary_key.empty()) { - return std::make_unique(input, node_symbol, prim_label, primary_key); + // Mark the expressions so they won't be used for an additional, unnecessary filter. + for (const auto &pk : primary_key) { + filter_exprs_for_removal_.insert(pk.first); + filters_.EraseFilter(pk.second); + } + EraseLabelFilters(node_symbol, prim_label); + std::vector pk_expressions; + std::transform(primary_key.begin(), primary_key.end(), std::back_inserter(pk_expressions), + [](const auto &exp) { return exp.first; }); + return std::make_unique(input, node_symbol, GetLabel(prim_label), pk_expressions); } } @@ -593,9 +618,7 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor { filter_exprs_for_removal_.insert(found_index->filter.expression); } filters_.EraseFilter(found_index->filter); - std::vector removed_expressions; - filters_.EraseLabelFilter(node_symbol, found_index->label, &removed_expressions); - filter_exprs_for_removal_.insert(removed_expressions.begin(), removed_expressions.end()); + EraseLabelFilters(node_symbol, found_index->label); if (prop_filter.lower_bound_ || prop_filter.upper_bound_) { return std::make_unique( input, node_symbol, GetLabel(found_index->label), GetProperty(prop_filter.property_), diff --git a/src/query/v2/plan/vertex_count_cache.hpp b/src/query/v2/plan/vertex_count_cache.hpp index 70b8d324b..120e434ea 100644 --- a/src/query/v2/plan/vertex_count_cache.hpp +++ b/src/query/v2/plan/vertex_count_cache.hpp @@ -58,9 +58,9 @@ class VertexCountCache { bool LabelPropertyIndexExists(storage::v3::LabelId /*label*/, storage::v3::PropertyId /*property*/) { return false; } - std::vector ExtractPrimaryKey(storage::v3::LabelId label, - std::vector property_filters) { - std::vector pk; + std::vector> ExtractPrimaryKey( + storage::v3::LabelId label, std::vector property_filters) { + std::vector> pk; const auto schema = shard_request_manager_->GetSchemaForLabel(label); std::vector schema_properties; @@ -72,11 +72,13 @@ class VertexCountCache { for (const auto &property_filter : property_filters) { const auto &property_id = NameToProperty(property_filter.property_filter->property_.name); if (std::find(schema_properties.begin(), schema_properties.end(), property_id) != schema_properties.end()) { - pk.push_back(property_filter.expression); + pk.emplace_back(std::make_pair(property_filter.expression, property_filter)); } } - return pk.size() == schema_properties.size() ? pk : std::vector{}; + return pk.size() == schema_properties.size() + ? pk + : std::vector>{}; } msgs::ShardRequestManagerInterface *shard_request_manager_; diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index d9f9df875..123a448f5 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -93,8 +93,8 @@ target_link_libraries(${test_prefix}query_expression_evaluator mg-query) add_unit_test(query_plan.cpp) target_link_libraries(${test_prefix}query_plan mg-query) -add_unit_test(query_plan_v2.cpp) -target_link_libraries(${test_prefix}query_plan_v2 mg-query-v2) +add_unit_test(query_v2_plan.cpp) +target_link_libraries(${test_prefix}query_v2_plan mg-query-v2) add_unit_test(query_plan_accumulate_aggregate.cpp) target_link_libraries(${test_prefix}query_plan_accumulate_aggregate mg-query) diff --git a/tests/unit/query_common.hpp b/tests/unit/query_common.hpp index 4697124ef..f7fe6e805 100644 --- a/tests/unit/query_common.hpp +++ b/tests/unit/query_common.hpp @@ -84,38 +84,14 @@ struct OrderBy { std::vector expressions; }; -// new stuff begin - -struct OrderByv2 { - std::vector expressions; -}; - -// new stuff end - struct Skip { Expression *expression = nullptr; }; -// new stuff begin - -struct Skipv2 { - query::v2::Expression *expression = nullptr; -}; - -// new stuff end - struct Limit { Expression *expression = nullptr; }; -// new stuff begin - -struct Limitv2 { - query::v2::Expression *expression = nullptr; -}; - -// new stuff end - struct OnMatch { std::vector set; }; @@ -182,42 +158,6 @@ auto GetPropertyLookup(AstStorage &storage, TDbAccessor &, Expression *expr, return storage.Create(expr, storage.GetPropertyIx(prop_pair.first)); } -// new stuff begin - -template -auto GetPropertyLookup(memgraph::query::v2::AstStorage &storage, TDbAccessor &dba, const std::string &name, - memgraph::storage::v3::PropertyId property) { - return storage.Create(storage.Create(name), - storage.GetPropertyIx(dba.PropertyToName(property))); -} - -template -auto GetPropertyLookup(memgraph::query::v2::AstStorage &storage, TDbAccessor &dba, - memgraph::query::v2::Expression *expr, memgraph::storage::v3::PropertyId property) { - return storage.Create(expr, storage.GetPropertyIx(dba.PropertyToName(property))); -} - -template -auto GetPropertyLookup(memgraph::query::v2::AstStorage &storage, TDbAccessor &dba, - memgraph::query::v2::Expression *expr, const std::string &property) { - return storage.Create(expr, storage.GetPropertyIx(property)); -} - -template -auto GetPropertyLookup(memgraph::query::v2::AstStorage &storage, TDbAccessor &, const std::string &name, - const std::pair &prop_pair) { - return storage.Create(storage.Create(name), - storage.GetPropertyIx(prop_pair.first)); -} - -template -auto GetPropertyLookup(memgraph::query::v2::AstStorage &storage, TDbAccessor &, memgraph::query::v2::Expression *expr, - const std::pair &prop_pair) { - return storage.Create(expr, storage.GetPropertyIx(prop_pair.first)); -} - -// new stuff end - /// Create an EdgeAtom with given name, direction and edge_type. /// /// Name is used to create the Identifier which is assigned to the edge. @@ -276,15 +216,6 @@ auto GetNode(AstStorage &storage, const std::string &name, std::optional label = std::nullopt) { - auto node = storage.Create(storage.Create(name)); - if (label) node->labels_.emplace_back(storage.GetLabelIx(*label)); - return node; -} -// new stuff end - /// Create a Pattern with given atoms. auto GetPattern(AstStorage &storage, std::vector atoms) { auto pattern = storage.Create(); @@ -301,26 +232,6 @@ auto GetPattern(AstStorage &storage, const std::string &name, std::vector atoms) { - auto pattern = storage.Create(); - pattern->identifier_ = storage.Create(memgraph::utils::RandomString(20), false); - pattern->atoms_.insert(pattern->atoms_.begin(), atoms.begin(), atoms.end()); - return pattern; -} - -/// Create a Pattern with given name and atoms. -auto GetPattern(memgraph::query::v2::AstStorage &storage, const std::string &name, - std::vector atoms) { - auto pattern = storage.Create(); - pattern->identifier_ = storage.Create(name, true); - pattern->atoms_.insert(pattern->atoms_.begin(), atoms.begin(), atoms.end()); - return pattern; -} - -// new stuff end - /// This function fills an AST node which with given patterns. /// /// The function is most commonly used to create Match and Create clauses. @@ -330,16 +241,6 @@ auto GetWithPatterns(TWithPatterns *with_patterns, std::vector patter return with_patterns; } -// new stuff begin - -template -auto GetWithPatterns(TWithPatterns *with_patterns, std::vector patterns) { - with_patterns->patterns_.insert(with_patterns->patterns_.begin(), patterns.begin(), patterns.end()); - return with_patterns; -} - -// new stuff end - /// Create a query with given clauses. auto GetSingleQuery(SingleQuery *single_query, Clause *clause) { @@ -375,45 +276,6 @@ auto GetSingleQuery(SingleQuery *single_query, Clause *clause, T *...clauses) { return GetSingleQuery(single_query, clauses...); } -// new stuff begin - -auto GetSingleQuery(query::v2::SingleQuery *single_query, query::v2::Clause *clause) { - single_query->clauses_.emplace_back(clause); - return single_query; -} -auto GetSingleQuery(query::v2::SingleQuery *single_query, query::v2::Match *match, query::v2::Where *where) { - match->where_ = where; - single_query->clauses_.emplace_back(match); - return single_query; -} -auto GetSingleQuery(query::v2::SingleQuery *single_query, query::v2::With *with, query::v2::Where *where) { - with->where_ = where; - single_query->clauses_.emplace_back(with); - return single_query; -} -template -auto GetSingleQuery(query::v2::SingleQuery *single_query, query::v2::Match *match, query::v2::Where *where, - T *...clauses) { - match->where_ = where; - single_query->clauses_.emplace_back(match); - return GetSingleQuery(single_query, clauses...); -} -template -auto GetSingleQuery(query::v2::SingleQuery *single_query, query::v2::With *with, query::v2::Where *where, - T *...clauses) { - with->where_ = where; - single_query->clauses_.emplace_back(with); - return GetSingleQuery(single_query, clauses...); -} - -template -auto GetSingleQuery(query::v2::SingleQuery *single_query, query::v2::Clause *clause, T *...clauses) { - single_query->clauses_.emplace_back(clause); - return GetSingleQuery(single_query, clauses...); -} - -// new stuff end - auto GetCypherUnion(CypherUnion *cypher_union, SingleQuery *single_query) { cypher_union->single_query_ = single_query; return cypher_union; @@ -433,24 +295,6 @@ auto GetQuery(AstStorage &storage, SingleQuery *single_query, T *...cypher_union return query; } -// new stuff begin - -auto GetQuery(query::v2::AstStorage &storage, query::v2::SingleQuery *single_query) { - auto *query = storage.Create(); - query->single_query_ = single_query; - return query; -} - -template -auto GetQuery(query::v2::AstStorage &storage, query::v2::SingleQuery *single_query, T *...cypher_unions) { - auto *query = storage.Create(); - query->single_query_ = single_query; - query->cypher_unions_ = std::vector{cypher_unions...}; - return query; -} - -// new stuff end - // Helper functions for constructing RETURN and WITH clauses. void FillReturnBody(AstStorage &, ReturnBody &body, NamedExpression *named_expr) { body.named_expressions.emplace_back(named_expr); @@ -514,80 +358,6 @@ void FillReturnBody(AstStorage &storage, ReturnBody &body, const std::string &na FillReturnBody(storage, body, rest...); } -// new stuff begin - -void FillReturnBody(query::v2::AstStorage &, query::v2::ReturnBody &body, query::v2::NamedExpression *named_expr) { - body.named_expressions.emplace_back(named_expr); -} -void FillReturnBody(query::v2::AstStorage &storage, query::v2::ReturnBody &body, const std::string &name) { - if (name == "*") { - body.all_identifiers = true; - } else { - auto *ident = storage.Create(name); - auto *named_expr = storage.Create(name, ident); - body.named_expressions.emplace_back(named_expr); - } -} -void FillReturnBody(query::v2::AstStorage &, query::v2::ReturnBody &body, Limitv2 limit) { - body.limit = limit.expression; -} -void FillReturnBody(query::v2::AstStorage &, query::v2::ReturnBody &body, Skipv2 skip, Limitv2 limit = Limitv2{}) { - body.skip = skip.expression; - body.limit = limit.expression; -} -void FillReturnBody(query::v2::AstStorage &, query::v2::ReturnBody &body, OrderByv2 order_by, - Limitv2 limit = Limitv2{}) { - body.order_by = order_by.expressions; - body.limit = limit.expression; -} -void FillReturnBody(query::v2::AstStorage &, query::v2::ReturnBody &body, OrderByv2 order_by, Skipv2 skip, - Limitv2 limit = Limitv2{}) { - body.order_by = order_by.expressions; - body.skip = skip.expression; - body.limit = limit.expression; -} -void FillReturnBody(query::v2::AstStorage &, query::v2::ReturnBody &body, query::v2::Expression *expr, - query::v2::NamedExpression *named_expr) { - // This overload supports `RETURN(expr, AS(name))` construct, since - // NamedExpression does not inherit Expression. - named_expr->expression_ = expr; - body.named_expressions.emplace_back(named_expr); -} -void FillReturnBody(query::v2::AstStorage &storage, query::v2::ReturnBody &body, const std::string &name, - query::v2::NamedExpression *named_expr) { - named_expr->expression_ = storage.Create(name); - body.named_expressions.emplace_back(named_expr); -} -template -void FillReturnBody(query::v2::AstStorage &storage, query::v2::ReturnBody &body, query::v2::Expression *expr, - query::v2::NamedExpression *named_expr, T... rest) { - named_expr->expression_ = expr; - body.named_expressions.emplace_back(named_expr); - FillReturnBody(storage, body, rest...); -} -template -void FillReturnBody(query::v2::AstStorage &storage, query::v2::ReturnBody &body, query::v2::NamedExpression *named_expr, - T... rest) { - body.named_expressions.emplace_back(named_expr); - FillReturnBody(storage, body, rest...); -} -template -void FillReturnBody(query::v2::AstStorage &storage, query::v2::ReturnBody &body, const std::string &name, - query::v2::NamedExpression *named_expr, T... rest) { - named_expr->expression_ = storage.Create(name); - body.named_expressions.emplace_back(named_expr); - FillReturnBody(storage, body, rest...); -} -template -void FillReturnBody(query::v2::AstStorage &storage, query::v2::ReturnBody &body, const std::string &name, T... rest) { - auto *ident = storage.Create(name); - auto *named_expr = storage.Create(name, ident); - body.named_expressions.emplace_back(named_expr); - FillReturnBody(storage, body, rest...); -} - -// new stuff end - /// Create the return clause with given expressions. /// /// The supported expression combination of arguments is: @@ -609,18 +379,6 @@ auto GetReturn(AstStorage &storage, bool distinct, T... exprs) { return ret; } -// new stuff begin - -template -auto GetReturn(query::v2::AstStorage &storage, bool distinct, T... exprs) { - auto ret = storage.Create(); - ret->body_.distinct = distinct; - FillReturnBody(storage, ret->body_, exprs...); - return ret; -} - -// new stuff end - /// Create the with clause with given expressions. /// /// The supported expression combination is the same as for @c GetReturn. @@ -739,10 +497,7 @@ auto GetForeach(AstStorage &storage, NamedExpression *named_expr, const std::vec memgraph::query::test_common::GetWithPatterns(storage.Create(true), {__VA_ARGS__}) #define MATCH(...) \ memgraph::query::test_common::GetWithPatterns(storage.Create(), {__VA_ARGS__}) -#define MATCH_V2(...) \ - memgraph::query::test_common::GetWithPatterns(storage.Create(), {__VA_ARGS__}) #define WHERE(expr) storage.Create((expr)) -#define WHERE_V2(expr) storage.Create((expr)) #define CREATE(...) \ memgraph::query::test_common::GetWithPatterns(storage.Create(), {__VA_ARGS__}) #define IDENT(...) storage.Create(__VA_ARGS__) @@ -789,8 +544,6 @@ auto GetForeach(AstStorage &storage, NamedExpression *named_expr, const std::vec std::vector{(property)}) #define QUERY(...) memgraph::query::test_common::GetQuery(storage, __VA_ARGS__) #define SINGLE_QUERY(...) memgraph::query::test_common::GetSingleQuery(storage.Create(), __VA_ARGS__) -#define SINGLE_QUERY_V2(...) \ - memgraph::query::test_common::GetSingleQuery(storage.Create(), __VA_ARGS__) #define UNION(...) memgraph::query::test_common::GetCypherUnion(storage.Create(true), __VA_ARGS__) #define UNION_ALL(...) memgraph::query::test_common::GetCypherUnion(storage.Create(false), __VA_ARGS__) #define FOREACH(...) memgraph::query::test_common::GetForeach(storage, __VA_ARGS__) diff --git a/tests/unit/query_plan_checker_v2.hpp b/tests/unit/query_plan_checker_v2.hpp index 6a6a12680..1ad1e712a 100644 --- a/tests/unit/query_plan_checker_v2.hpp +++ b/tests/unit/query_plan_checker_v2.hpp @@ -61,6 +61,7 @@ class PlanChecker : public virtual HierarchicalLogicalOperatorVisitor { PRE_VISIT(ScanAllByLabelPropertyValue); PRE_VISIT(ScanAllByLabelPropertyRange); PRE_VISIT(ScanAllByLabelProperty); + PRE_VISIT(ScanAllByPrimaryKey); PRE_VISIT(Expand); PRE_VISIT(ExpandVariable); PRE_VISIT(Filter); @@ -270,25 +271,25 @@ using ExpectDistinct = OpChecker; // const std::list &optional_; // }; -// class ExpectScanAllByLabelPropertyValue : public OpChecker { -// public: -// ExpectScanAllByLabelPropertyValue(memgraph::storage::LabelId label, -// const std::pair &prop_pair, -// memgraph::query::Expression *expression) -// : label_(label), property_(prop_pair.second), expression_(expression) {} +class ExpectScanAllByLabelPropertyValue : public OpChecker { + public: + ExpectScanAllByLabelPropertyValue(memgraph::storage::v3::LabelId label, + const std::pair &prop_pair, + memgraph::query::v2::Expression *expression) + : label_(label), property_(prop_pair.second), expression_(expression) {} -// void ExpectOp(ScanAllByLabelPropertyValue &scan_all, const SymbolTable &) override { -// EXPECT_EQ(scan_all.label_, label_); -// EXPECT_EQ(scan_all.property_, property_); -// // TODO: Proper expression equality -// EXPECT_EQ(typeid(scan_all.expression_).hash_code(), typeid(expression_).hash_code()); -// } + void ExpectOp(ScanAllByLabelPropertyValue &scan_all, const SymbolTable &) override { + EXPECT_EQ(scan_all.label_, label_); + EXPECT_EQ(scan_all.property_, property_); + // TODO: Proper expression equality + EXPECT_EQ(typeid(scan_all.expression_).hash_code(), typeid(expression_).hash_code()); + } -// private: -// memgraph::storage::LabelId label_; -// memgraph::storage::PropertyId property_; -// memgraph::query::Expression *expression_; -// }; + private: + memgraph::storage::v3::LabelId label_; + memgraph::storage::v3::PropertyId property_; + memgraph::query::v2::Expression *expression_; +}; // class ExpectScanAllByLabelPropertyRange : public OpChecker { // public: @@ -536,12 +537,53 @@ class FakeDistributedDbAccessor { } memgraph::storage::v3::PropertyId NameToProperty(const std::string &name) { - return storage::v3::PropertyId::FromUint(0); + auto find_in_prim_properties = primary_properties_.find(name); + if (find_in_prim_properties != primary_properties_.end()) { + return find_in_prim_properties->second; + } + auto find_in_secondary_properties = secondary_properties_.find(name); + if (find_in_secondary_properties != secondary_properties_.end()) { + return find_in_secondary_properties->second; + } + + MG_ASSERT(false, "The property does not exist as a primary or a secondary property."); + return memgraph::storage::v3::PropertyId::FromUint(0); } - std::vector ExtractPrimaryKey(storage::v3::LabelId label, - std::vector property_filters) { - return std::vector{}; + std::vector> ExtractPrimaryKey( + storage::v3::LabelId label, std::vector property_filters) { + MG_ASSERT(schemas_.contains(label), + "You did not specify the Schema for this label! Use FakeDistributedDbAccessor::CreateSchema(...)."); + + std::vector> pk; + const auto schema = GetSchemaForLabel(label); + + std::vector schema_properties; + schema_properties.reserve(schema.size()); + + std::transform(schema.begin(), schema.end(), std::back_inserter(schema_properties), + [](const auto &schema_elem) { return schema_elem; }); + + for (const auto &property_filter : property_filters) { + const auto &property_id = NameToProperty(property_filter.property_filter->property_.name); + if (std::find(schema_properties.begin(), schema_properties.end(), property_id) != schema_properties.end()) { + pk.emplace_back(std::make_pair(property_filter.expression, property_filter)); + } + } + + return pk.size() == schema_properties.size() + ? pk + : std::vector>{}; + } + + std::vector GetSchemaForLabel(storage::v3::LabelId label) { + return schemas_.at(label); + } + + void CreateSchema(const memgraph::storage::v3::LabelId primary_label, + const std::vector &schemas_types) { + MG_ASSERT(!schemas_.contains(primary_label), "You already created the schema for this label!"); + schemas_.emplace(primary_label, schemas_types); } private: @@ -554,6 +596,8 @@ class FakeDistributedDbAccessor { std::unordered_map label_index_; std::vector> label_property_index_; + + std::unordered_map> schemas_; }; } // namespace memgraph::query::v2::plan diff --git a/tests/unit/query_v2_common.hpp b/tests/unit/query_v2_common.hpp new file mode 100644 index 000000000..85905cb22 --- /dev/null +++ b/tests/unit/query_v2_common.hpp @@ -0,0 +1,603 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +/// @file +/// This file provides macros for easier construction of openCypher query AST. +/// The usage of macros is very similar to how one would write openCypher. For +/// example: +/// +/// AstStorage storage; // Macros rely on storage being in scope. +/// // PROPERTY_LOOKUP and PROPERTY_PAIR macros +/// // rely on a DbAccessor *reference* named dba. +/// database::GraphDb db; +/// auto dba_ptr = db.Access(); +/// auto &dba = *dba_ptr; +/// +/// QUERY(MATCH(PATTERN(NODE("n"), EDGE("e"), NODE("m"))), +/// WHERE(LESS(PROPERTY_LOOKUP("e", edge_prop), LITERAL(3))), +/// RETURN(SUM(PROPERTY_LOOKUP("m", prop)), AS("sum"), +/// ORDER_BY(IDENT("sum")), +/// SKIP(ADD(LITERAL(1), LITERAL(2))))); +/// +/// Each of the macros is accompanied by a function. The functions use overload +/// resolution and template magic to provide a type safe way of constructing +/// queries. Although the functions can be used by themselves, it is more +/// convenient to use the macros. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "query/frontend/ast/pretty_print.hpp" // not sure if that is ok... +#include "query/v2/frontend/ast/ast.hpp" +#include "storage/v3/id_types.hpp" +#include "utils/string.hpp" + +#include "query/v2/frontend/ast/ast.hpp" + +namespace memgraph::query::v2 { + +namespace test_common { + +auto ToIntList(const TypedValue &t) { + std::vector list; + for (auto x : t.ValueList()) { + list.push_back(x.ValueInt()); + } + return list; +}; + +auto ToIntMap(const TypedValue &t) { + std::map map; + for (const auto &kv : t.ValueMap()) map.emplace(kv.first, kv.second.ValueInt()); + return map; +}; + +std::string ToString(Expression *expr) { + std::ostringstream ss; + // PrintExpression(expr, &ss); + return ss.str(); +} + +std::string ToString(NamedExpression *expr) { + std::ostringstream ss; + // PrintExpression(expr, &ss); + return ss.str(); +} + +// Custom types for ORDER BY, SKIP, LIMIT, ON MATCH and ON CREATE expressions, +// so that they can be used to resolve function calls. +struct OrderBy { + std::vector expressions; +}; + +struct Skip { + Expression *expression = nullptr; +}; + +struct Limit { + Expression *expression = nullptr; +}; + +struct OnMatch { + std::vector set; +}; +struct OnCreate { + std::vector set; +}; + +// Helper functions for filling the OrderBy with expressions. +auto FillOrderBy(OrderBy &order_by, Expression *expression, Ordering ordering = Ordering::ASC) { + order_by.expressions.push_back({ordering, expression}); +} +template +auto FillOrderBy(OrderBy &order_by, Expression *expression, Ordering ordering, T... rest) { + FillOrderBy(order_by, expression, ordering); + FillOrderBy(order_by, rest...); +} +template +auto FillOrderBy(OrderBy &order_by, Expression *expression, T... rest) { + FillOrderBy(order_by, expression); + FillOrderBy(order_by, rest...); +} + +/// Create OrderBy expressions. +/// +/// The supported combination of arguments is: (Expression, [Ordering])+ +/// Since the Ordering is optional, by default it is ascending. +template +auto GetOrderBy(T... exprs) { + OrderBy order_by; + FillOrderBy(order_by, exprs...); + return order_by; +} + +/// Create PropertyLookup with given name and property. +/// +/// Name is used to create the Identifier which is used for property lookup. +template +auto GetPropertyLookup(AstStorage &storage, TDbAccessor &dba, const std::string &name, + memgraph::storage::v3::PropertyId property) { + return storage.Create(storage.Create(name), + storage.GetPropertyIx(dba.PropertyToName(property))); +} + +template +auto GetPropertyLookup(AstStorage &storage, TDbAccessor &dba, Expression *expr, + memgraph::storage::v3::PropertyId property) { + return storage.Create(expr, storage.GetPropertyIx(dba.PropertyToName(property))); +} + +template +auto GetPropertyLookup(AstStorage &storage, TDbAccessor &dba, Expression *expr, const std::string &property) { + return storage.Create(expr, storage.GetPropertyIx(property)); +} + +template +auto GetPropertyLookup(AstStorage &storage, TDbAccessor &, const std::string &name, + const std::pair &prop_pair) { + return storage.Create(storage.Create(name), storage.GetPropertyIx(prop_pair.first)); +} + +template +auto GetPropertyLookup(AstStorage &storage, TDbAccessor &, Expression *expr, + const std::pair &prop_pair) { + return storage.Create(expr, storage.GetPropertyIx(prop_pair.first)); +} + +/// Create an EdgeAtom with given name, direction and edge_type. +/// +/// Name is used to create the Identifier which is assigned to the edge. +auto GetEdge(AstStorage &storage, const std::string &name, EdgeAtom::Direction dir = EdgeAtom::Direction::BOTH, + const std::vector &edge_types = {}) { + std::vector types; + types.reserve(edge_types.size()); + for (const auto &type : edge_types) { + types.push_back(storage.GetEdgeTypeIx(type)); + } + return storage.Create(storage.Create(name), EdgeAtom::Type::SINGLE, dir, types); +} + +/// Create a variable length expansion EdgeAtom with given name, direction and +/// edge_type. +/// +/// Name is used to create the Identifier which is assigned to the edge. +auto GetEdgeVariable(AstStorage &storage, const std::string &name, EdgeAtom::Type type = EdgeAtom::Type::DEPTH_FIRST, + EdgeAtom::Direction dir = EdgeAtom::Direction::BOTH, + const std::vector &edge_types = {}, Identifier *flambda_inner_edge = nullptr, + Identifier *flambda_inner_node = nullptr, Identifier *wlambda_inner_edge = nullptr, + Identifier *wlambda_inner_node = nullptr, Expression *wlambda_expression = nullptr, + Identifier *total_weight = nullptr) { + std::vector types; + types.reserve(edge_types.size()); + for (const auto &type : edge_types) { + types.push_back(storage.GetEdgeTypeIx(type)); + } + auto r_val = storage.Create(storage.Create(name), type, dir, types); + + r_val->filter_lambda_.inner_edge = + flambda_inner_edge ? flambda_inner_edge : storage.Create(memgraph::utils::RandomString(20)); + r_val->filter_lambda_.inner_node = + flambda_inner_node ? flambda_inner_node : storage.Create(memgraph::utils::RandomString(20)); + + if (type == EdgeAtom::Type::WEIGHTED_SHORTEST_PATH) { + r_val->weight_lambda_.inner_edge = + wlambda_inner_edge ? wlambda_inner_edge : storage.Create(memgraph::utils::RandomString(20)); + r_val->weight_lambda_.inner_node = + wlambda_inner_node ? wlambda_inner_node : storage.Create(memgraph::utils::RandomString(20)); + r_val->weight_lambda_.expression = + wlambda_expression ? wlambda_expression : storage.Create(1); + + r_val->total_weight_ = total_weight; + } + + return r_val; +} + +/// Create a NodeAtom with given name and label. +/// +/// Name is used to create the Identifier which is assigned to the node. +auto GetNode(AstStorage &storage, const std::string &name, std::optional label = std::nullopt) { + auto node = storage.Create(storage.Create(name)); + if (label) node->labels_.emplace_back(storage.GetLabelIx(*label)); + return node; +} + +/// Create a Pattern with given atoms. +auto GetPattern(AstStorage &storage, std::vector atoms) { + auto pattern = storage.Create(); + pattern->identifier_ = storage.Create(memgraph::utils::RandomString(20), false); + pattern->atoms_.insert(pattern->atoms_.begin(), atoms.begin(), atoms.end()); + return pattern; +} + +/// Create a Pattern with given name and atoms. +auto GetPattern(AstStorage &storage, const std::string &name, std::vector atoms) { + auto pattern = storage.Create(); + pattern->identifier_ = storage.Create(name, true); + pattern->atoms_.insert(pattern->atoms_.begin(), atoms.begin(), atoms.end()); + return pattern; +} + +/// This function fills an AST node which with given patterns. +/// +/// The function is most commonly used to create Match and Create clauses. +template +auto GetWithPatterns(TWithPatterns *with_patterns, std::vector patterns) { + with_patterns->patterns_.insert(with_patterns->patterns_.begin(), patterns.begin(), patterns.end()); + return with_patterns; +} + +/// Create a query with given clauses. + +auto GetSingleQuery(SingleQuery *single_query, Clause *clause) { + single_query->clauses_.emplace_back(clause); + return single_query; +} +auto GetSingleQuery(SingleQuery *single_query, Match *match, Where *where) { + match->where_ = where; + single_query->clauses_.emplace_back(match); + return single_query; +} +auto GetSingleQuery(SingleQuery *single_query, With *with, Where *where) { + with->where_ = where; + single_query->clauses_.emplace_back(with); + return single_query; +} +template +auto GetSingleQuery(SingleQuery *single_query, Match *match, Where *where, T *...clauses) { + match->where_ = where; + single_query->clauses_.emplace_back(match); + return GetSingleQuery(single_query, clauses...); +} +template +auto GetSingleQuery(SingleQuery *single_query, With *with, Where *where, T *...clauses) { + with->where_ = where; + single_query->clauses_.emplace_back(with); + return GetSingleQuery(single_query, clauses...); +} + +template +auto GetSingleQuery(SingleQuery *single_query, Clause *clause, T *...clauses) { + single_query->clauses_.emplace_back(clause); + return GetSingleQuery(single_query, clauses...); +} + +auto GetCypherUnion(CypherUnion *cypher_union, SingleQuery *single_query) { + cypher_union->single_query_ = single_query; + return cypher_union; +} + +auto GetQuery(AstStorage &storage, SingleQuery *single_query) { + auto *query = storage.Create(); + query->single_query_ = single_query; + return query; +} + +template +auto GetQuery(AstStorage &storage, SingleQuery *single_query, T *...cypher_unions) { + auto *query = storage.Create(); + query->single_query_ = single_query; + query->cypher_unions_ = std::vector{cypher_unions...}; + return query; +} + +// Helper functions for constructing RETURN and WITH clauses. +void FillReturnBody(AstStorage &, ReturnBody &body, NamedExpression *named_expr) { + body.named_expressions.emplace_back(named_expr); +} +void FillReturnBody(AstStorage &storage, ReturnBody &body, const std::string &name) { + if (name == "*") { + body.all_identifiers = true; + } else { + auto *ident = storage.Create(name); + auto *named_expr = storage.Create(name, ident); + body.named_expressions.emplace_back(named_expr); + } +} +void FillReturnBody(AstStorage &, ReturnBody &body, Limit limit) { body.limit = limit.expression; } +void FillReturnBody(AstStorage &, ReturnBody &body, Skip skip, Limit limit = Limit{}) { + body.skip = skip.expression; + body.limit = limit.expression; +} +void FillReturnBody(AstStorage &, ReturnBody &body, OrderBy order_by, Limit limit = Limit{}) { + body.order_by = order_by.expressions; + body.limit = limit.expression; +} +void FillReturnBody(AstStorage &, ReturnBody &body, OrderBy order_by, Skip skip, Limit limit = Limit{}) { + body.order_by = order_by.expressions; + body.skip = skip.expression; + body.limit = limit.expression; +} +void FillReturnBody(AstStorage &, ReturnBody &body, Expression *expr, NamedExpression *named_expr) { + // This overload supports `RETURN(expr, AS(name))` construct, since + // NamedExpression does not inherit Expression. + named_expr->expression_ = expr; + body.named_expressions.emplace_back(named_expr); +} +void FillReturnBody(AstStorage &storage, ReturnBody &body, const std::string &name, NamedExpression *named_expr) { + named_expr->expression_ = storage.Create(name); + body.named_expressions.emplace_back(named_expr); +} +template +void FillReturnBody(AstStorage &storage, ReturnBody &body, Expression *expr, NamedExpression *named_expr, T... rest) { + named_expr->expression_ = expr; + body.named_expressions.emplace_back(named_expr); + FillReturnBody(storage, body, rest...); +} +template +void FillReturnBody(AstStorage &storage, ReturnBody &body, NamedExpression *named_expr, T... rest) { + body.named_expressions.emplace_back(named_expr); + FillReturnBody(storage, body, rest...); +} +template +void FillReturnBody(AstStorage &storage, ReturnBody &body, const std::string &name, NamedExpression *named_expr, + T... rest) { + named_expr->expression_ = storage.Create(name); + body.named_expressions.emplace_back(named_expr); + FillReturnBody(storage, body, rest...); +} +template +void FillReturnBody(AstStorage &storage, ReturnBody &body, const std::string &name, T... rest) { + auto *ident = storage.Create(name); + auto *named_expr = storage.Create(name, ident); + body.named_expressions.emplace_back(named_expr); + FillReturnBody(storage, body, rest...); +} + +/// Create the return clause with given expressions. +/// +/// The supported expression combination of arguments is: +/// +/// (String | NamedExpression | (Expression NamedExpression))+ +/// [OrderBy] [Skip] [Limit] +/// +/// When the pair (Expression NamedExpression) is given, the Expression will be +/// moved inside the NamedExpression. This is done, so that the constructs like +/// RETURN(expr, AS("name"), ...) are supported. Taking a String is a shorthand +/// for RETURN(IDENT(string), AS(string), ....). +/// +/// @sa GetWith +template +auto GetReturn(AstStorage &storage, bool distinct, T... exprs) { + auto ret = storage.Create(); + ret->body_.distinct = distinct; + FillReturnBody(storage, ret->body_, exprs...); + return ret; +} + +/// Create the with clause with given expressions. +/// +/// The supported expression combination is the same as for @c GetReturn. +/// +/// @sa GetReturn +template +auto GetWith(AstStorage &storage, bool distinct, T... exprs) { + auto with = storage.Create(); + with->body_.distinct = distinct; + FillReturnBody(storage, with->body_, exprs...); + return with; +} + +/// Create the UNWIND clause with given named expression. +auto GetUnwind(AstStorage &storage, NamedExpression *named_expr) { + return storage.Create(named_expr); +} +auto GetUnwind(AstStorage &storage, Expression *expr, NamedExpression *as) { + as->expression_ = expr; + return GetUnwind(storage, as); +} + +/// Create the delete clause with given named expressions. +auto GetDelete(AstStorage &storage, std::vector exprs, bool detach = false) { + auto del = storage.Create(); + del->expressions_.insert(del->expressions_.begin(), exprs.begin(), exprs.end()); + del->detach_ = detach; + return del; +} + +/// Create a set property clause for given property lookup and the right hand +/// side expression. +auto GetSet(AstStorage &storage, PropertyLookup *prop_lookup, Expression *expr) { + return storage.Create(prop_lookup, expr); +} + +/// Create a set properties clause for given identifier name and the right hand +/// side expression. +auto GetSet(AstStorage &storage, const std::string &name, Expression *expr, bool update = false) { + return storage.Create(storage.Create(name), expr, update); +} + +/// Create a set labels clause for given identifier name and labels. +auto GetSet(AstStorage &storage, const std::string &name, std::vector label_names) { + std::vector labels; + labels.reserve(label_names.size()); + for (const auto &label : label_names) { + labels.push_back(storage.GetLabelIx(label)); + } + return storage.Create(storage.Create(name), labels); +} + +/// Create a remove property clause for given property lookup +auto GetRemove(AstStorage &storage, PropertyLookup *prop_lookup) { return storage.Create(prop_lookup); } + +/// Create a remove labels clause for given identifier name and labels. +auto GetRemove(AstStorage &storage, const std::string &name, std::vector label_names) { + std::vector labels; + labels.reserve(label_names.size()); + for (const auto &label : label_names) { + labels.push_back(storage.GetLabelIx(label)); + } + return storage.Create(storage.Create(name), labels); +} + +/// Create a Merge clause for given Pattern with optional OnMatch and OnCreate +/// parts. +auto GetMerge(AstStorage &storage, Pattern *pattern, OnCreate on_create = OnCreate{}) { + auto *merge = storage.Create(); + merge->pattern_ = pattern; + merge->on_create_ = on_create.set; + return merge; +} +auto GetMerge(AstStorage &storage, Pattern *pattern, OnMatch on_match, OnCreate on_create = OnCreate{}) { + auto *merge = storage.Create(); + merge->pattern_ = pattern; + merge->on_match_ = on_match.set; + merge->on_create_ = on_create.set; + return merge; +} + +auto GetCallProcedure(AstStorage &storage, std::string procedure_name, + std::vector arguments = {}) { + auto *call_procedure = storage.Create(); + call_procedure->procedure_name_ = std::move(procedure_name); + call_procedure->arguments_ = std::move(arguments); + return call_procedure; +} + +/// Create the FOREACH clause with given named expression. +auto GetForeach(AstStorage &storage, NamedExpression *named_expr, const std::vector &clauses) { + return storage.Create(named_expr, clauses); +} + +} // namespace test_common + +} // namespace memgraph::query::v2 + +/// All the following macros implicitly pass `storage` variable to functions. +/// You need to have `AstStorage storage;` somewhere in scope to use them. +/// Refer to function documentation to see what the macro does. +/// +/// Example usage: +/// +/// // Create MATCH (n) -[r]- (m) RETURN m AS new_name +/// AstStorage storage; +/// auto query = QUERY(MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m"))), +/// RETURN(NEXPR("new_name"), IDENT("m"))); +#define NODE(...) memgraph::expr::test_common::GetNode(storage, __VA_ARGS__) +#define EDGE(...) memgraph::expr::test_common::GetEdge(storage, __VA_ARGS__) +#define EDGE_VARIABLE(...) memgraph::expr::test_common::GetEdgeVariable(storage, __VA_ARGS__) +#define PATTERN(...) memgraph::expr::test_common::GetPattern(storage, {__VA_ARGS__}) +#define PATTERN(...) memgraph::expr::test_common::GetPattern(storage, {__VA_ARGS__}) +#define NAMED_PATTERN(name, ...) memgraph::expr::test_common::GetPattern(storage, name, {__VA_ARGS__}) +#define OPTIONAL_MATCH(...) \ + memgraph::expr::test_common::GetWithPatterns(storage.Create(true), {__VA_ARGS__}) +#define MATCH(...) \ + memgraph::expr::test_common::GetWithPatterns(storage.Create(), {__VA_ARGS__}) +#define WHERE(expr) storage.Create((expr)) +#define CREATE(...) \ + memgraph::expr::test_common::GetWithPatterns(storage.Create(), {__VA_ARGS__}) +#define IDENT(...) storage.Create(__VA_ARGS__) +#define LITERAL(val) storage.Create((val)) +#define LIST(...) \ + storage.Create(std::vector{__VA_ARGS__}) +#define MAP(...) \ + storage.Create( \ + std::unordered_map{__VA_ARGS__}) +#define PROPERTY_PAIR(property_name) \ + std::make_pair(property_name, dba.NameToProperty(property_name)) // This one might not be needed at all +#define PRIMARY_PROPERTY_PAIR(property_name) std::make_pair(property_name, dba.NameToPrimaryProperty(property_name)) +#define SECONDARY_PROPERTY_PAIR(property_name) std::make_pair(property_name, dba.NameToSecondaryProperty(property_name)) +#define PROPERTY_LOOKUP(...) memgraph::expr::test_common::GetPropertyLookup(storage, dba, __VA_ARGS__) +#define PARAMETER_LOOKUP(token_position) storage.Create((token_position)) +#define NEXPR(name, expr) storage.Create((name), (expr)) +// AS is alternative to NEXPR which does not initialize NamedExpression with +// Expression. It should be used with RETURN or WITH. For example: +// RETURN(IDENT("n"), AS("n")) vs. RETURN(NEXPR("n", IDENT("n"))). +#define AS(name) storage.Create((name)) +#define RETURN(...) memgraph::expr::test_common::GetReturn(storage, false, __VA_ARGS__) +#define WITH(...) memgraph::expr::test_common::GetWith(storage, false, __VA_ARGS__) +#define RETURN_DISTINCT(...) memgraph::expr::test_common::GetReturn(storage, true, __VA_ARGS__) +#define WITH_DISTINCT(...) memgraph::expr::test_common::GetWith(storage, true, __VA_ARGS__) +#define UNWIND(...) memgraph::expr::test_common::GetUnwind(storage, __VA_ARGS__) +#define ORDER_BY(...) memgraph::expr::test_common::GetOrderBy(__VA_ARGS__) +#define SKIP(expr) \ + memgraph::expr::test_common::Skip { (expr) } +#define LIMIT(expr) \ + memgraph::expr::test_common::Limit { (expr) } +#define DELETE(...) memgraph::expr::test_common::GetDelete(storage, {__VA_ARGS__}) +#define DETACH_DELETE(...) memgraph::expr::test_common::GetDelete(storage, {__VA_ARGS__}, true) +#define SET(...) memgraph::expr::test_common::GetSet(storage, __VA_ARGS__) +#define REMOVE(...) memgraph::expr::test_common::GetRemove(storage, __VA_ARGS__) +#define MERGE(...) memgraph::expr::test_common::GetMerge(storage, __VA_ARGS__) +#define ON_MATCH(...) \ + memgraph::expr::test_common::OnMatch { \ + std::vector { __VA_ARGS__ } \ + } +#define ON_CREATE(...) \ + memgraph::expr::test_common::OnCreate { \ + std::vector { __VA_ARGS__ } \ + } +#define CREATE_INDEX_ON(label, property) \ + storage.Create(memgraph::query::v2::IndexQuery::Action::CREATE, (label), \ + std::vector{(property)}) +#define QUERY(...) memgraph::expr::test_common::GetQuery(storage, __VA_ARGS__) +#define SINGLE_QUERY(...) memgraph::expr::test_common::GetSingleQuery(storage.Create(), __VA_ARGS__) +#define UNION(...) memgraph::expr::test_common::GetCypherUnion(storage.Create(true), __VA_ARGS__) +#define UNION_ALL(...) memgraph::expr::test_common::GetCypherUnion(storage.Create(false), __VA_ARGS__) +#define FOREACH(...) memgraph::expr::test_common::GetForeach(storage, __VA_ARGS__) +// Various operators +#define NOT(expr) storage.Create((expr)) +#define UPLUS(expr) storage.Create((expr)) +#define UMINUS(expr) storage.Create((expr)) +#define IS_NULL(expr) storage.Create((expr)) +#define ADD(expr1, expr2) storage.Create((expr1), (expr2)) +#define LESS(expr1, expr2) storage.Create((expr1), (expr2)) +#define LESS_EQ(expr1, expr2) storage.Create((expr1), (expr2)) +#define GREATER(expr1, expr2) storage.Create((expr1), (expr2)) +#define GREATER_EQ(expr1, expr2) storage.Create((expr1), (expr2)) +#define SUM(expr) \ + storage.Create((expr), nullptr, memgraph::query::v2::Aggregation::Op::SUM) +#define COUNT(expr) \ + storage.Create((expr), nullptr, memgraph::query::v2::Aggregation::Op::COUNT) +#define AVG(expr) \ + storage.Create((expr), nullptr, memgraph::query::v2::Aggregation::Op::AVG) +#define COLLECT_LIST(expr) \ + storage.Create((expr), nullptr, memgraph::query::v2::Aggregation::Op::COLLECT_LIST) +#define EQ(expr1, expr2) storage.Create((expr1), (expr2)) +#define NEQ(expr1, expr2) storage.Create((expr1), (expr2)) +#define AND(expr1, expr2) storage.Create((expr1), (expr2)) +#define OR(expr1, expr2) storage.Create((expr1), (expr2)) +#define IN_LIST(expr1, expr2) storage.Create((expr1), (expr2)) +#define IF(cond, then, else) storage.Create((cond), (then), (else)) +// Function call +#define FN(function_name, ...) \ + storage.Create(memgraph::utils::ToUpperCase(function_name), \ + std::vector{__VA_ARGS__}) +// List slicing +#define SLICE(list, lower_bound, upper_bound) \ + storage.Create(list, lower_bound, upper_bound) +// all(variable IN list WHERE predicate) +#define ALL(variable, list, where) \ + storage.Create(storage.Create(variable), list, where) +#define SINGLE(variable, list, where) \ + storage.Create(storage.Create(variable), list, where) +#define ANY(variable, list, where) \ + storage.Create(storage.Create(variable), list, where) +#define NONE(variable, list, where) \ + storage.Create(storage.Create(variable), list, where) +#define REDUCE(accumulator, initializer, variable, list, expr) \ + storage.Create(storage.Create(accumulator), \ + initializer, storage.Create(variable), \ + list, expr) +#define COALESCE(...) \ + storage.Create(std::vector{__VA_ARGS__}) +#define EXTRACT(variable, list, expr) \ + storage.Create(storage.Create(variable), list, expr) +#define AUTH_QUERY(action, user, role, user_or_role, password, privileges) \ + storage.Create((action), (user), (role), (user_or_role), password, (privileges)) +#define DROP_USER(usernames) storage.Create((usernames)) +#define CALL_PROCEDURE(...) memgraph::query::v2::test_common::GetCallProcedure(storage, __VA_ARGS__) diff --git a/tests/unit/query_plan_v2.cpp b/tests/unit/query_v2_plan.cpp similarity index 95% rename from tests/unit/query_plan_v2.cpp rename to tests/unit/query_v2_plan.cpp index 009d0ca56..2eb8295db 100644 --- a/tests/unit/query_plan_v2.cpp +++ b/tests/unit/query_v2_plan.cpp @@ -22,14 +22,13 @@ #include #include -#include "query/v2/frontend/ast/ast.hpp" -// #include "query/frontend/semantic/symbol_generator.hpp" #include "expr/semantic/symbol_generator.hpp" #include "query/frontend/semantic/symbol_table.hpp" +#include "query/v2/frontend/ast/ast.hpp" #include "query/v2/plan/operator.hpp" #include "query/v2/plan/planner.hpp" -#include "query_common.hpp" +#include "query_v2_common.hpp" namespace memgraph::query { ::std::ostream &operator<<(::std::ostream &os, const Symbol &sym) { @@ -39,10 +38,10 @@ namespace memgraph::query { // using namespace memgraph::query::v2::plan; using namespace memgraph::expr::plan; -using memgraph::query::AstStorage; -using memgraph::query::SingleQuery; using memgraph::query::Symbol; using memgraph::query::SymbolGenerator; +using memgraph::query::v2::AstStorage; +using memgraph::query::v2::SingleQuery; using memgraph::query::v2::SymbolTable; using Type = memgraph::query::v2::EdgeAtom::Type; using Direction = memgraph::query::v2::EdgeAtom::Direction; @@ -75,8 +74,8 @@ auto CheckPlan(LogicalOperator &plan, const SymbolTable &symbol_table, TChecker. } template -auto CheckPlan(memgraph::query::CypherQuery *query, AstStorage &storage, TChecker... checker) { - auto symbol_table = memgraph::query::MakeSymbolTable(query); +auto CheckPlan(memgraph::query::v2::CypherQuery *query, AstStorage &storage, TChecker... checker) { + auto symbol_table = memgraph::expr::MakeSymbolTable(query); FakeDistributedDbAccessor dba; auto planner = MakePlanner(&dba, storage, symbol_table, query); CheckPlan(planner.plan(), symbol_table, checker...); @@ -95,45 +94,91 @@ void DeleteListContent(std::list *list) { TYPED_TEST_CASE(TestPlanner, PlannerTypes); TYPED_TEST(TestPlanner, MatchFilterPropIsNotNull) { - FakeDistributedDbAccessor dba; - auto label = dba.Label("prim_label_one"); - auto prim_prop_one = PRIMARY_PROPERTY_PAIR("prim_prop_one"); - // auto prim_prop_two = PRIMARY_PROPERTY_PAIR("prim_prop_two"); - auto sec_prop_one = PRIMARY_PROPERTY_PAIR("sec_prop_one"); - auto sec_prop_two = PRIMARY_PROPERTY_PAIR("sec_prop_two"); - auto sec_prop_three = PRIMARY_PROPERTY_PAIR("sec_prop_three"); - dba.SetIndexCount(label, 1); - dba.SetIndexCount(label, prim_prop_one.second, 1); - // dba.SetIndexCount(label, prim_prop_two.second, 1); - dba.SetIndexCount(label, sec_prop_one.second, 1); - dba.SetIndexCount(label, sec_prop_two.second, 1); - dba.SetIndexCount(label, sec_prop_three.second, 1); - memgraph::query::v2::AstStorage storage; + const char *prim_label_name = "prim_label_one"; + // Exact primary key match, one elem as PK. { + FakeDistributedDbAccessor dba; + auto label = dba.Label(prim_label_name); + auto prim_prop_one = PRIMARY_PROPERTY_PAIR("prim_prop_one"); + + dba.SetIndexCount(label, 1); + dba.SetIndexCount(label, prim_prop_one.second, 1); + + dba.CreateSchema(label, {prim_prop_one.second}); + + memgraph::query::v2::AstStorage storage; + memgraph::query::v2::Expression *expected_primary_key; - - // Pray and hope for the best... expected_primary_key = PROPERTY_LOOKUP("n", prim_prop_one); - - // auto asd1 = NODE("n", "label"); - // auto asd2 = PATTERN(NODE("n", "label")); - // auto asd3 = MATCH_V2(PATTERN(NODE("n", "label"))); - // auto asd4 = WHERE_V2(PROPERTY_LOOKUP("n", prim_prop_one)); - - auto *query = QUERY(SINGLE_QUERY_V2(MATCH_V2(PATTERN(NODE("n", "label"))), - WHERE_V2(PROPERTY_LOOKUP("n", prim_prop_one)), RETURN("n"))); + auto *query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n", prim_label_name))), + WHERE(EQ(PROPERTY_LOOKUP("n", prim_prop_one), LITERAL(1))), RETURN("n"))); auto symbol_table = (memgraph::expr::MakeSymbolTable(query)); auto planner = MakePlanner(&dba, storage, symbol_table, query); CheckPlan(planner.plan(), symbol_table, ExpectScanAllByPrimaryKey(label, {expected_primary_key}), ExpectProduce()); + } + // Exact primary key match, two elem as PK. + { + FakeDistributedDbAccessor dba; + auto label = dba.Label(prim_label_name); + auto prim_prop_one = PRIMARY_PROPERTY_PAIR("prim_prop_one"); - // // Test MATCH (n :label) -[r]- (m) WHERE n.prop IS NOT NULL RETURN n - // auto *query2 = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n", "label"), EDGE("r"), NODE("m"))), - // WHERE(NOT(IS_NULL(PROPERTY_LOOKUP("n", prop)))), RETURN("n"))); - // auto symbol_table = memgraph::query::MakeSymbolTable(query); - // auto planner = MakePlanner(&dba, storage, symbol_table, query); - // // We expect ScanAllByLabelProperty to come instead of ScanAll > Filter. - // CheckPlan(planner.plan(), symbol_table, ExpectScanAllByLabelProperty(label, prop), ExpectExpand(), - // ExpectProduce()); + auto prim_prop_two = PRIMARY_PROPERTY_PAIR("prim_prop_two"); + auto sec_prop_one = PRIMARY_PROPERTY_PAIR("sec_prop_one"); + auto sec_prop_two = PRIMARY_PROPERTY_PAIR("sec_prop_two"); + auto sec_prop_three = PRIMARY_PROPERTY_PAIR("sec_prop_three"); + + dba.SetIndexCount(label, 1); + dba.SetIndexCount(label, prim_prop_one.second, 1); + + dba.CreateSchema(label, {prim_prop_one.second, prim_prop_two.second}); + + dba.SetIndexCount(label, prim_prop_two.second, 1); + dba.SetIndexCount(label, sec_prop_one.second, 1); + dba.SetIndexCount(label, sec_prop_two.second, 1); + dba.SetIndexCount(label, sec_prop_three.second, 1); + memgraph::query::v2::AstStorage storage; + + memgraph::query::v2::Expression *expected_primary_key; + expected_primary_key = PROPERTY_LOOKUP("n", prim_prop_one); + auto *query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n", prim_label_name))), + WHERE(AND(EQ(PROPERTY_LOOKUP("n", prim_prop_one), LITERAL(1)), + EQ(PROPERTY_LOOKUP("n", prim_prop_two), LITERAL(1)))), + RETURN("n"))); + auto symbol_table = (memgraph::expr::MakeSymbolTable(query)); + auto planner = MakePlanner(&dba, storage, symbol_table, query); + CheckPlan(planner.plan(), symbol_table, ExpectScanAllByPrimaryKey(label, {expected_primary_key}), ExpectProduce()); + } + // One elem is missing from PK, default to ScanAllByLabelPropertyValue. + { + FakeDistributedDbAccessor dba; + auto label = dba.Label(prim_label_name); + + auto prim_prop_one = PRIMARY_PROPERTY_PAIR("prim_prop_one"); + auto prim_prop_two = PRIMARY_PROPERTY_PAIR("prim_prop_two"); + + auto sec_prop_one = PRIMARY_PROPERTY_PAIR("sec_prop_one"); + auto sec_prop_two = PRIMARY_PROPERTY_PAIR("sec_prop_two"); + auto sec_prop_three = PRIMARY_PROPERTY_PAIR("sec_prop_three"); + + dba.SetIndexCount(label, 1); + dba.SetIndexCount(label, prim_prop_one.second, 1); + + dba.CreateSchema(label, {prim_prop_one.second, prim_prop_two.second}); + + dba.SetIndexCount(label, prim_prop_two.second, 1); + dba.SetIndexCount(label, sec_prop_one.second, 1); + dba.SetIndexCount(label, sec_prop_two.second, 1); + dba.SetIndexCount(label, sec_prop_three.second, 1); + memgraph::query::v2::AstStorage storage; + + memgraph::query::v2::Expression *expected_primary_key; + expected_primary_key = PROPERTY_LOOKUP("n", prim_prop_one); + auto *query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n", prim_label_name))), + WHERE(EQ(PROPERTY_LOOKUP("n", prim_prop_one), LITERAL(1))), RETURN("n"))); + auto symbol_table = (memgraph::expr::MakeSymbolTable(query)); + auto planner = MakePlanner(&dba, storage, symbol_table, query); + CheckPlan(planner.plan(), symbol_table, ExpectScanAllByLabelPropertyValue(label, prim_prop_one, IDENT("n")), + ExpectProduce()); } }