diff --git a/src/expr/interpret/frame.hpp b/src/expr/interpret/frame.hpp index 9f4068226..457806680 100644 --- a/src/expr/interpret/frame.hpp +++ b/src/expr/interpret/frame.hpp @@ -23,9 +23,9 @@ namespace memgraph::expr { class Frame { public: /// Create a Frame of given size backed by a utils::NewDeleteResource() - explicit Frame(int64_t size) : elems_(size, utils::NewDeleteResource()) { MG_ASSERT(size >= 0); } + explicit Frame(size_t size) : elems_(size, utils::NewDeleteResource()) { MG_ASSERT(size >= 0); } - Frame(int64_t size, utils::MemoryResource *memory) : elems_(size, memory) { MG_ASSERT(size >= 0); } + Frame(size_t size, utils::MemoryResource *memory) : elems_(size, memory) { MG_ASSERT(size >= 0); } TypedValue &operator[](const Symbol &symbol) { return elems_[symbol.position()]; } const TypedValue &operator[](const Symbol &symbol) const { return elems_[symbol.position()]; } @@ -44,9 +44,9 @@ class Frame { class FrameWithValidity final : public Frame { public: - explicit FrameWithValidity(int64_t size) : Frame(size), is_valid_(false) {} + explicit FrameWithValidity(size_t size) : Frame(size), is_valid_(false) {} - FrameWithValidity(int64_t size, utils::MemoryResource *memory) : Frame(size, memory), is_valid_(false) {} + FrameWithValidity(size_t size, utils::MemoryResource *memory) : Frame(size, memory), is_valid_(false) {} bool IsValid() const noexcept { return is_valid_; } void MakeValid() noexcept { is_valid_ = true; } diff --git a/src/query/v2/context.hpp b/src/query/v2/context.hpp index cb30a9ced..58f5ada97 100644 --- a/src/query/v2/context.hpp +++ b/src/query/v2/context.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source diff --git a/src/query/v2/plan/cost_estimator.hpp b/src/query/v2/plan/cost_estimator.hpp index 8cba2505f..250274a73 100644 --- a/src/query/v2/plan/cost_estimator.hpp +++ b/src/query/v2/plan/cost_estimator.hpp @@ -149,8 +149,6 @@ class CostEstimator : public HierarchicalLogicalOperatorVisitor { return true; } - // TODO: Cost estimate ScanAllById? - // For the given op first increments the cardinality and then cost. #define POST_VISIT_CARD_FIRST(NAME) \ bool PostVisit(NAME &) override { \ diff --git a/src/query/v2/plan/operator.cpp b/src/query/v2/plan/operator.cpp index 15f34fe34..4b5999138 100644 --- a/src/query/v2/plan/operator.cpp +++ b/src/query/v2/plan/operator.cpp @@ -34,6 +34,7 @@ #include "query/v2/bindings/eval.hpp" #include "query/v2/bindings/symbol_table.hpp" #include "query/v2/context.hpp" +#include "query/v2/conversions.hpp" #include "query/v2/db_accessor.hpp" #include "query/v2/exceptions.hpp" #include "query/v2/frontend/ast/ast.hpp" @@ -93,7 +94,7 @@ extern const Event ScanAllByLabelOperator; extern const Event ScanAllByLabelPropertyRangeOperator; extern const Event ScanAllByLabelPropertyValueOperator; extern const Event ScanAllByLabelPropertyOperator; -extern const Event ScanAllByIdOperator; +extern const Event ScanByPrimaryKeyOperator; extern const Event ExpandOperator; extern const Event ExpandVariableOperator; extern const Event ConstructNamedPathOperator; @@ -572,6 +573,88 @@ class DistributedScanAllAndFilterCursor : public Cursor { ValidFramesConsumer::Iterator own_frames_it_; }; +class DistributedScanByPrimaryKeyCursor : public Cursor { + public: + explicit DistributedScanByPrimaryKeyCursor(Symbol output_symbol, UniqueCursorPtr input_cursor, const char *op_name, + storage::v3::LabelId label, + std::optional<std::vector<Expression *>> filter_expressions, + std::vector<Expression *> primary_key) + : output_symbol_(output_symbol), + input_cursor_(std::move(input_cursor)), + op_name_(op_name), + label_(label), + filter_expressions_(filter_expressions), + primary_key_(primary_key) {} + + enum class State : int8_t { INITIALIZING, COMPLETED }; + + using VertexAccessor = accessors::VertexAccessor; + + std::optional<VertexAccessor> MakeRequestSingleFrame(Frame &frame, RequestRouterInterface &request_router, + ExecutionContext &context) { + // Evaluate the expressions that hold the PrimaryKey. + ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.request_router, + storage::v3::View::NEW); + + std::vector<msgs::Value> pk; + for (auto *primary_property : primary_key_) { + pk.push_back(TypedValueToValue(primary_property->Accept(evaluator))); + } + + msgs::Label label = {.id = msgs::LabelId::FromUint(label_.AsUint())}; + + msgs::GetPropertiesRequest req = {.vertex_ids = {std::make_pair(label, pk)}}; + auto get_prop_result = std::invoke([&context, &request_router, &req]() mutable { + SCOPED_REQUEST_WAIT_PROFILE; + return request_router.GetProperties(req); + }); + MG_ASSERT(get_prop_result.size() <= 1); + + if (get_prop_result.empty()) { + return std::nullopt; + } + auto properties = get_prop_result[0].props; + // TODO (gvolfing) figure out labels when relevant. + msgs::Vertex vertex = {.id = get_prop_result[0].vertex, .labels = {}}; + + return VertexAccessor(vertex, properties, &request_router); + } + + bool Pull(Frame &frame, ExecutionContext &context) override { + SCOPED_PROFILE_OP(op_name_); + + if (MustAbort(context)) { + throw HintedAbortError(); + } + + while (input_cursor_->Pull(frame, context)) { + auto &request_router = *context.request_router; + auto vertex = MakeRequestSingleFrame(frame, request_router, context); + if (vertex) { + frame[output_symbol_] = TypedValue(std::move(*vertex)); + return true; + } + } + return false; + } + + void PullMultiple(MultiFrame & /*input_multi_frame*/, ExecutionContext & /*context*/) override { + throw utils::NotYetImplemented("Multiframe version of ScanByPrimaryKey is yet to be implemented."); + }; + + void Reset() override { input_cursor_->Reset(); } + + void Shutdown() override { input_cursor_->Shutdown(); } + + private: + const Symbol output_symbol_; + const UniqueCursorPtr input_cursor_; + const char *op_name_; + storage::v3::LabelId label_; + std::optional<std::vector<Expression *>> filter_expressions_; + std::vector<Expression *> primary_key_; +}; + ScanAll::ScanAll(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, storage::v3::View view) : input_(input ? input : std::make_shared<Once>()), output_symbol_(output_symbol), view_(view) {} @@ -666,22 +749,21 @@ UniqueCursorPtr ScanAllByLabelProperty::MakeCursor(utils::MemoryResource *mem) c throw QueryRuntimeException("ScanAllByLabelProperty is not supported"); } -ScanAllById::ScanAllById(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, Expression *expression, - storage::v3::View view) - : ScanAll(input, output_symbol, view), expression_(expression) { - MG_ASSERT(expression); +ScanByPrimaryKey::ScanByPrimaryKey(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, + storage::v3::LabelId label, std::vector<query::v2::Expression *> primary_key, + storage::v3::View view) + : ScanAll(input, output_symbol, view), label_(label), primary_key_(primary_key) { + MG_ASSERT(primary_key.front()); } -ACCEPT_WITH_INPUT(ScanAllById) +ACCEPT_WITH_INPUT(ScanByPrimaryKey) -UniqueCursorPtr ScanAllById::MakeCursor(utils::MemoryResource *mem) const { - EventCounter::IncrementCounter(EventCounter::ScanAllByIdOperator); - // TODO Reimplement when we have reliable conversion between hash value and pk - auto vertices = [](Frame & /*frame*/, ExecutionContext & /*context*/) -> std::optional<std::vector<VertexAccessor>> { - return std::nullopt; - }; - return MakeUniqueCursorPtr<ScanAllCursor<decltype(vertices)>>(mem, output_symbol_, input_->MakeCursor(mem), - std::move(vertices), "ScanAllById"); +UniqueCursorPtr ScanByPrimaryKey::MakeCursor(utils::MemoryResource *mem) const { + EventCounter::IncrementCounter(EventCounter::ScanByPrimaryKeyOperator); + + return MakeUniqueCursorPtr<DistributedScanByPrimaryKeyCursor>(mem, output_symbol_, input_->MakeCursor(mem), + "ScanByPrimaryKey", label_, + std::nullopt /*filter_expressions*/, primary_key_); } Expand::Expand(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, Symbol node_symbol, diff --git a/src/query/v2/plan/operator.lcp b/src/query/v2/plan/operator.lcp index efa0d5df0..e2d3eca56 100644 --- a/src/query/v2/plan/operator.lcp +++ b/src/query/v2/plan/operator.lcp @@ -113,7 +113,7 @@ class ScanAllByLabel; class ScanAllByLabelPropertyRange; class ScanAllByLabelPropertyValue; class ScanAllByLabelProperty; -class ScanAllById; +class ScanByPrimaryKey; class Expand; class ExpandVariable; class ConstructNamedPath; @@ -144,7 +144,7 @@ class Foreach; using LogicalOperatorCompositeVisitor = utils::CompositeVisitor< Once, CreateNode, CreateExpand, ScanAll, ScanAllByLabel, ScanAllByLabelPropertyRange, ScanAllByLabelPropertyValue, - ScanAllByLabelProperty, ScanAllById, + ScanAllByLabelProperty, ScanByPrimaryKey, Expand, ExpandVariable, ConstructNamedPath, Filter, Produce, Delete, SetProperty, SetProperties, SetLabels, RemoveProperty, RemoveLabels, EdgeUniquenessFilter, Accumulate, Aggregate, Skip, Limit, OrderBy, Merge, @@ -845,19 +845,21 @@ given label and property. (:serialize (:slk)) (:clone)) - - -(lcp:define-class scan-all-by-id (scan-all) - ((expression "Expression *" :scope :public +(lcp:define-class scan-by-primary-key (scan-all) + ((label "::storage::v3::LabelId" :scope :public) + (primary-key "std::vector<Expression*>" :scope :public) + (expression "Expression *" :scope :public :slk-save #'slk-save-ast-pointer :slk-load (slk-load-ast-pointer "Expression"))) (:documentation - "ScanAll producing a single node with ID equal to evaluated expression") + "ScanAll producing a single node with specified by the label and primary key") (:public #>cpp - ScanAllById() {} - ScanAllById(const std::shared_ptr<LogicalOperator> &input, - Symbol output_symbol, Expression *expression, + ScanByPrimaryKey() {} + ScanByPrimaryKey(const std::shared_ptr<LogicalOperator> &input, + Symbol output_symbol, + storage::v3::LabelId label, + std::vector<query::v2::Expression*> primary_key, storage::v3::View view = storage::v3::View::OLD); bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; diff --git a/src/query/v2/plan/pretty_print.cpp b/src/query/v2/plan/pretty_print.cpp index bc30a3890..7eb1dd5a9 100644 --- a/src/query/v2/plan/pretty_print.cpp +++ b/src/query/v2/plan/pretty_print.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -86,10 +86,10 @@ bool PlanPrinter::PreVisit(query::v2::plan::ScanAllByLabelProperty &op) { return true; } -bool PlanPrinter::PreVisit(ScanAllById &op) { +bool PlanPrinter::PreVisit(query::v2::plan::ScanByPrimaryKey &op) { WithPrintLn([&](auto &out) { - out << "* ScanAllById" - << " (" << op.output_symbol_.name() << ")"; + out << "* ScanByPrimaryKey" + << " (" << op.output_symbol_.name() << " :" << request_router_->LabelToName(op.label_) << ")"; }); return true; } @@ -487,12 +487,15 @@ bool PlanToJsonVisitor::PreVisit(ScanAllByLabelProperty &op) { return false; } -bool PlanToJsonVisitor::PreVisit(ScanAllById &op) { +bool PlanToJsonVisitor::PreVisit(ScanByPrimaryKey &op) { json self; - self["name"] = "ScanAllById"; + self["name"] = "ScanByPrimaryKey"; + self["label"] = ToJson(op.label_, *request_router_); self["output_symbol"] = ToJson(op.output_symbol_); + op.input_->Accept(*this); self["input"] = PopOutput(); + output_ = std::move(self); return false; } diff --git a/src/query/v2/plan/pretty_print.hpp b/src/query/v2/plan/pretty_print.hpp index 4094d7c81..31ff17b18 100644 --- a/src/query/v2/plan/pretty_print.hpp +++ b/src/query/v2/plan/pretty_print.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -67,7 +67,7 @@ class PlanPrinter : public virtual HierarchicalLogicalOperatorVisitor { bool PreVisit(ScanAllByLabelPropertyValue &) override; bool PreVisit(ScanAllByLabelPropertyRange &) override; bool PreVisit(ScanAllByLabelProperty &) override; - bool PreVisit(ScanAllById &) override; + bool PreVisit(ScanByPrimaryKey & /*unused*/) override; bool PreVisit(Expand &) override; bool PreVisit(ExpandVariable &) override; @@ -194,7 +194,7 @@ class PlanToJsonVisitor : public virtual HierarchicalLogicalOperatorVisitor { bool PreVisit(ScanAllByLabelPropertyRange &) override; bool PreVisit(ScanAllByLabelPropertyValue &) override; bool PreVisit(ScanAllByLabelProperty &) override; - bool PreVisit(ScanAllById &) override; + bool PreVisit(ScanByPrimaryKey & /*unused*/) override; bool PreVisit(Produce &) override; bool PreVisit(Accumulate &) override; diff --git a/src/query/v2/plan/read_write_type_checker.cpp b/src/query/v2/plan/read_write_type_checker.cpp index 6cc38cedf..28c1e8d0a 100644 --- a/src/query/v2/plan/read_write_type_checker.cpp +++ b/src/query/v2/plan/read_write_type_checker.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -11,10 +11,12 @@ #include "query/v2/plan/read_write_type_checker.hpp" -#define PRE_VISIT(TOp, RWType, continue_visiting) \ - bool ReadWriteTypeChecker::PreVisit(TOp &op) { \ - UpdateType(RWType); \ - return continue_visiting; \ +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +#define PRE_VISIT(TOp, RWType, continue_visiting) \ + /*NOLINTNEXTLINE(bugprone-macro-parentheses)*/ \ + bool ReadWriteTypeChecker::PreVisit(TOp & /*op*/) { \ + UpdateType(RWType); \ + return continue_visiting; \ } namespace memgraph::query::v2::plan { @@ -35,7 +37,7 @@ PRE_VISIT(ScanAllByLabel, RWType::R, true) PRE_VISIT(ScanAllByLabelPropertyRange, RWType::R, true) PRE_VISIT(ScanAllByLabelPropertyValue, RWType::R, true) PRE_VISIT(ScanAllByLabelProperty, RWType::R, true) -PRE_VISIT(ScanAllById, RWType::R, true) +PRE_VISIT(ScanByPrimaryKey, RWType::R, true) PRE_VISIT(Expand, RWType::R, true) PRE_VISIT(ExpandVariable, RWType::R, true) diff --git a/src/query/v2/plan/read_write_type_checker.hpp b/src/query/v2/plan/read_write_type_checker.hpp index a3c2f1a46..626cf7af3 100644 --- a/src/query/v2/plan/read_write_type_checker.hpp +++ b/src/query/v2/plan/read_write_type_checker.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -59,7 +59,7 @@ class ReadWriteTypeChecker : public virtual HierarchicalLogicalOperatorVisitor { bool PreVisit(ScanAllByLabelPropertyValue &) override; bool PreVisit(ScanAllByLabelPropertyRange &) override; bool PreVisit(ScanAllByLabelProperty &) override; - bool PreVisit(ScanAllById &) override; + bool PreVisit(ScanByPrimaryKey & /*unused*/) override; bool PreVisit(Expand &) override; bool PreVisit(ExpandVariable &) override; diff --git a/src/query/v2/plan/rewrite/index_lookup.hpp b/src/query/v2/plan/rewrite/index_lookup.hpp index 57ddba54e..17996d952 100644 --- a/src/query/v2/plan/rewrite/index_lookup.hpp +++ b/src/query/v2/plan/rewrite/index_lookup.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -25,8 +25,10 @@ #include <gflags/gflags.h> +#include "query/v2/frontend/ast/ast.hpp" #include "query/v2/plan/operator.hpp" #include "query/v2/plan/preprocess.hpp" +#include "storage/v3/id_types.hpp" DECLARE_int64(query_vertex_count_to_expand_existing); @@ -271,11 +273,12 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor { return true; } - bool PreVisit(ScanAllById &op) override { + bool PreVisit(ScanByPrimaryKey &op) override { prev_ops_.push_back(&op); return true; } - bool PostVisit(ScanAllById &) override { + + bool PostVisit(ScanByPrimaryKey & /*unused*/) override { prev_ops_.pop_back(); return true; } @@ -487,6 +490,12 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor { storage::v3::PropertyId GetProperty(PropertyIx prop) { return db_->NameToProperty(prop.name); } + void EraseLabelFilters(const memgraph::query::v2::Symbol &node_symbol, memgraph::query::v2::LabelIx prim_label) { + std::vector<query::v2::Expression *> removed_expressions; + filters_.EraseLabelFilter(node_symbol, prim_label, &removed_expressions); + filter_exprs_for_removal_.insert(removed_expressions.begin(), removed_expressions.end()); + } + std::optional<LabelIx> FindBestLabelIndex(const std::unordered_set<LabelIx> &labels) { MG_ASSERT(!labels.empty(), "Trying to find the best label without any labels."); std::optional<LabelIx> best_label; @@ -559,31 +568,81 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor { const auto &view = scan.view_; const auto &modified_symbols = scan.ModifiedSymbols(*symbol_table_); std::unordered_set<Symbol> bound_symbols(modified_symbols.begin(), modified_symbols.end()); - auto are_bound = [&bound_symbols](const auto &used_symbols) { - for (const auto &used_symbol : used_symbols) { - if (!utils::Contains(bound_symbols, used_symbol)) { - return false; - } - } - return true; - }; - // First, try to see if we can find a vertex by ID. - if (!max_vertex_count || *max_vertex_count >= 1) { - for (const auto &filter : filters_.IdFilters(node_symbol)) { - if (filter.id_filter->is_symbol_in_value_ || !are_bound(filter.used_symbols)) continue; - auto *value = filter.id_filter->value_; - filter_exprs_for_removal_.insert(filter.expression); - filters_.EraseFilter(filter); - return std::make_unique<ScanAllById>(input, node_symbol, value, view); - } - } - // Now try to see if we can use label+property index. If not, try to use - // just the label index. + + // Try to see if we can use label + primary-key or label + property index. + // If not, try to use just the label index. const auto labels = filters_.FilteredLabels(node_symbol); if (labels.empty()) { // Without labels, we cannot generate any indexed ScanAll. return nullptr; } + + // First, try to see if we can find a vertex based on the possibly + // supplied primary key. + auto property_filters = filters_.PropertyFilters(node_symbol); + query::v2::LabelIx prim_label; + std::vector<std::pair<query::v2::Expression *, query::v2::plan::FilterInfo>> primary_key; + + auto extract_primary_key = [this](storage::v3::LabelId label, + std::vector<query::v2::plan::FilterInfo> property_filters) + -> std::vector<std::pair<query::v2::Expression *, query::v2::plan::FilterInfo>> { + std::vector<std::pair<query::v2::Expression *, query::v2::plan::FilterInfo>> pk_temp; + std::vector<std::pair<query::v2::Expression *, query::v2::plan::FilterInfo>> pk; + std::vector<memgraph::storage::v3::SchemaProperty> schema = db_->GetSchemaForLabel(label); + + std::vector<storage::v3::PropertyId> schema_properties; + schema_properties.reserve(schema.size()); + + std::transform(schema.begin(), schema.end(), std::back_inserter(schema_properties), + [](const auto &schema_elem) { return schema_elem.property_id; }); + + for (const auto &property_filter : property_filters) { + const auto &property_id = db_->NameToProperty(property_filter.property_filter->property_.name); + if (std::find(schema_properties.begin(), schema_properties.end(), property_id) != schema_properties.end()) { + pk_temp.emplace_back(std::make_pair(property_filter.expression, property_filter)); + } + } + + // Make sure pk is in the same order as schema_properties. + for (const auto &schema_prop : schema_properties) { + for (auto &pk_temp_prop : pk_temp) { + const auto &property_id = db_->NameToProperty(pk_temp_prop.second.property_filter->property_.name); + if (schema_prop == property_id) { + pk.push_back(pk_temp_prop); + } + } + } + MG_ASSERT(pk.size() == pk_temp.size(), + "The two vectors should represent the same primary key with a possibly different order of contained " + "elements."); + + return pk.size() == schema_properties.size() + ? pk + : std::vector<std::pair<query::v2::Expression *, query::v2::plan::FilterInfo>>{}; + }; + + if (!property_filters.empty()) { + for (const auto &label : labels) { + if (db_->PrimaryLabelExists(GetLabel(label))) { + prim_label = label; + primary_key = extract_primary_key(GetLabel(prim_label), property_filters); + break; + } + } + if (!primary_key.empty()) { + // Mark the expressions so they won't be used for an additional, unnecessary filter. + for (const auto &primary_property : primary_key) { + filter_exprs_for_removal_.insert(primary_property.first); + filters_.EraseFilter(primary_property.second); + } + EraseLabelFilters(node_symbol, prim_label); + std::vector<query::v2::Expression *> pk_expressions; + std::transform(primary_key.begin(), primary_key.end(), std::back_inserter(pk_expressions), + [](const auto &exp) { return exp.second.property_filter->value_; }); + return std::make_unique<ScanByPrimaryKey>(input, node_symbol, GetLabel(prim_label), pk_expressions); + } + } + auto found_index = FindBestLabelPropertyIndex(node_symbol, bound_symbols); if (found_index && // Use label+property index if we satisfy max_vertex_count. @@ -597,9 +656,7 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor { filter_exprs_for_removal_.insert(found_index->filter.expression); } filters_.EraseFilter(found_index->filter); - std::vector<Expression *> removed_expressions; - filters_.EraseLabelFilter(node_symbol, found_index->label, &removed_expressions); - filter_exprs_for_removal_.insert(removed_expressions.begin(), removed_expressions.end()); + EraseLabelFilters(node_symbol, found_index->label); if (prop_filter.lower_bound_ || prop_filter.upper_bound_) { return std::make_unique<ScanAllByLabelPropertyRange>( input, node_symbol, GetLabel(found_index->label), GetProperty(prop_filter.property_), diff --git a/src/query/v2/plan/vertex_count_cache.hpp b/src/query/v2/plan/vertex_count_cache.hpp index e68ce1220..ae6cdad31 100644 --- a/src/query/v2/plan/vertex_count_cache.hpp +++ b/src/query/v2/plan/vertex_count_cache.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -12,14 +12,17 @@ /// @file #pragma once +#include <iterator> #include <optional> #include "query/v2/bindings/typed_value.hpp" +#include "query/v2/plan/preprocess.hpp" #include "query/v2/request_router.hpp" #include "storage/v3/conversions.hpp" #include "storage/v3/id_types.hpp" #include "storage/v3/property_value.hpp" #include "utils/bound.hpp" +#include "utils/exceptions.hpp" #include "utils/fnv.hpp" namespace memgraph::query::v2::plan { @@ -52,11 +55,16 @@ class VertexCountCache { return 1; } - // For now return true if label is primary label - bool LabelIndexExists(storage::v3::LabelId label) { return request_router_->IsPrimaryLabel(label); } + bool LabelIndexExists(storage::v3::LabelId label) { return PrimaryLabelExists(label); } + + bool PrimaryLabelExists(storage::v3::LabelId label) { return request_router_->IsPrimaryLabel(label); } bool LabelPropertyIndexExists(storage::v3::LabelId /*label*/, storage::v3::PropertyId /*property*/) { return false; } + const std::vector<memgraph::storage::v3::SchemaProperty> &GetSchemaForLabel(storage::v3::LabelId label) { + return request_router_->GetSchemaForLabel(label); + } + RequestRouterInterface *request_router_; }; diff --git a/src/query/v2/request_router.hpp b/src/query/v2/request_router.hpp index 96d51b05f..c175ba413 100644 --- a/src/query/v2/request_router.hpp +++ b/src/query/v2/request_router.hpp @@ -114,6 +114,7 @@ class RequestRouterInterface { virtual std::optional<storage::v3::LabelId> MaybeNameToLabel(const std::string &name) const = 0; virtual bool IsPrimaryLabel(storage::v3::LabelId label) const = 0; virtual bool IsPrimaryKey(storage::v3::LabelId primary_label, storage::v3::PropertyId property) const = 0; + virtual const std::vector<coordinator::SchemaProperty> &GetSchemaForLabel(storage::v3::LabelId label) const = 0; }; // TODO(kostasrim)rename this class template @@ -232,12 +233,17 @@ class RequestRouter : public RequestRouterInterface { }) != schema_it->second.end(); } + const std::vector<coordinator::SchemaProperty> &GetSchemaForLabel(storage::v3::LabelId label) const override { + return shards_map_.schemas.at(label); + } + bool IsPrimaryLabel(storage::v3::LabelId label) const override { return shards_map_.label_spaces.contains(label); } // TODO(kostasrim) Simplify return result std::vector<VertexAccessor> ScanVertices(std::optional<std::string> label) override { // create requests - std::vector<ShardRequestState<msgs::ScanVerticesRequest>> requests_to_be_sent = RequestsForScanVertices(label); + auto requests_to_be_sent = RequestsForScanVertices(label); + spdlog::trace("created {} ScanVertices requests", requests_to_be_sent.size()); // begin all requests in parallel @@ -360,6 +366,7 @@ class RequestRouter : public RequestRouterInterface { } std::vector<msgs::GetPropertiesResultRow> GetProperties(msgs::GetPropertiesRequest requests) override { + requests.transaction_id = transaction_id_; // create requests std::vector<ShardRequestState<msgs::GetPropertiesRequest>> requests_to_be_sent = RequestsForGetProperties(std::move(requests)); diff --git a/src/storage/v3/shard_rsm.cpp b/src/storage/v3/shard_rsm.cpp index 67a58e5cc..881796a70 100644 --- a/src/storage/v3/shard_rsm.cpp +++ b/src/storage/v3/shard_rsm.cpp @@ -544,6 +544,7 @@ msgs::ReadResponses ShardRsm::HandleRead(msgs::GetPropertiesRequest &&req) { } const auto *schema = shard_->GetSchema(shard_->PrimaryLabel()); MG_ASSERT(schema); + return CollectAllPropertiesFromAccessor(v_acc, view, *schema); } diff --git a/src/utils/event_counter.cpp b/src/utils/event_counter.cpp index 634b6eae2..2928027ee 100644 --- a/src/utils/event_counter.cpp +++ b/src/utils/event_counter.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -25,6 +25,7 @@ M(ScanAllByLabelPropertyValueOperator, "Number of times ScanAllByLabelPropertyValue operator was used.") \ M(ScanAllByLabelPropertyOperator, "Number of times ScanAllByLabelProperty operator was used.") \ M(ScanAllByIdOperator, "Number of times ScanAllById operator was used.") \ + M(ScanByPrimaryKeyOperator, "Number of times ScanByPrimaryKey operator was used.") \ M(ExpandOperator, "Number of times Expand operator was used.") \ M(ExpandVariableOperator, "Number of times ExpandVariable operator was used.") \ M(ConstructNamedPathOperator, "Number of times ConstructNamedPath operator was used.") \ diff --git a/tests/mgbench/client.cpp b/tests/mgbench/client.cpp index e4b63d477..000c199fa 100644 --- a/tests/mgbench/client.cpp +++ b/tests/mgbench/client.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source diff --git a/tests/simulation/request_router.cpp b/tests/simulation/request_router.cpp index 4248e7876..037674b66 100644 --- a/tests/simulation/request_router.cpp +++ b/tests/simulation/request_router.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source diff --git a/tests/simulation/test_cluster.hpp b/tests/simulation/test_cluster.hpp index 2e8bdf92f..791b45faa 100644 --- a/tests/simulation/test_cluster.hpp +++ b/tests/simulation/test_cluster.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index e7efb7473..ae747385b 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -93,6 +93,9 @@ target_link_libraries(${test_prefix}query_expression_evaluator mg-query) add_unit_test(query_plan.cpp) target_link_libraries(${test_prefix}query_plan mg-query) +add_unit_test(query_v2_plan.cpp) +target_link_libraries(${test_prefix}query_v2_plan mg-query-v2) + add_unit_test(query_plan_accumulate_aggregate.cpp) target_link_libraries(${test_prefix}query_plan_accumulate_aggregate mg-query) diff --git a/tests/unit/high_density_shard_create_scan.cpp b/tests/unit/high_density_shard_create_scan.cpp index 9fabf6ccc..4a90b98b2 100644 --- a/tests/unit/high_density_shard_create_scan.cpp +++ b/tests/unit/high_density_shard_create_scan.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source diff --git a/tests/unit/machine_manager.cpp b/tests/unit/machine_manager.cpp index 74b7d3863..6cc6a3ff1 100644 --- a/tests/unit/machine_manager.cpp +++ b/tests/unit/machine_manager.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source diff --git a/tests/unit/mock_helpers.hpp b/tests/unit/mock_helpers.hpp index c522b8602..b7d764ac8 100644 --- a/tests/unit/mock_helpers.hpp +++ b/tests/unit/mock_helpers.hpp @@ -42,6 +42,7 @@ class MockedRequestRouter : public RequestRouterInterface { MOCK_METHOD(std::optional<storage::v3::LabelId>, MaybeNameToLabel, (const std::string &), (const)); MOCK_METHOD(bool, IsPrimaryLabel, (storage::v3::LabelId), (const)); MOCK_METHOD(bool, IsPrimaryKey, (storage::v3::LabelId, storage::v3::PropertyId), (const)); + MOCK_METHOD(const std::vector<coordinator::SchemaProperty> &, GetSchemaForLabel, (storage::v3::LabelId), (const)); }; class MockedLogicalOperator : public plan::LogicalOperator { diff --git a/tests/unit/query_common.hpp b/tests/unit/query_common.hpp index 903cacf12..784df5e21 100644 --- a/tests/unit/query_common.hpp +++ b/tests/unit/query_common.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -501,6 +501,8 @@ auto GetForeach(AstStorage &storage, NamedExpression *named_expr, const std::vec storage.Create<memgraph::query::MapLiteral>( \ std::unordered_map<memgraph::query::PropertyIx, memgraph::query::Expression *>{__VA_ARGS__}) #define PROPERTY_PAIR(property_name) std::make_pair(property_name, dba.NameToProperty(property_name)) +#define PRIMARY_PROPERTY_PAIR(property_name) std::make_pair(property_name, dba.NameToPrimaryProperty(property_name)) +#define SECONDARY_PROPERTY_PAIR(property_name) std::make_pair(property_name, dba.NameToSecondaryProperty(property_name)) #define PROPERTY_LOOKUP(...) memgraph::query::test_common::GetPropertyLookup(storage, dba, __VA_ARGS__) #define PARAMETER_LOOKUP(token_position) storage.Create<memgraph::query::ParameterLookup>((token_position)) #define NEXPR(name, expr) storage.Create<memgraph::query::NamedExpression>((name), (expr)) diff --git a/tests/unit/query_plan_checker.hpp b/tests/unit/query_plan_checker.hpp index 335b6ab2b..7907f167f 100644 --- a/tests/unit/query_plan_checker.hpp +++ b/tests/unit/query_plan_checker.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -17,6 +17,7 @@ #include "query/plan/operator.hpp" #include "query/plan/planner.hpp" #include "query/plan/preprocess.hpp" +#include "query/v2/plan/operator.hpp" namespace memgraph::query::plan { @@ -90,7 +91,7 @@ class PlanChecker : public virtual HierarchicalLogicalOperatorVisitor { } PRE_VISIT(Unwind); PRE_VISIT(Distinct); - + bool PreVisit(Foreach &op) override { CheckOp(op); return false; @@ -336,6 +337,35 @@ class ExpectScanAllByLabelProperty : public OpChecker<ScanAllByLabelProperty> { memgraph::storage::PropertyId property_; }; +class ExpectScanByPrimaryKey : public OpChecker<v2::plan::ScanByPrimaryKey> { + public: + ExpectScanByPrimaryKey(memgraph::storage::v3::LabelId label, const std::vector<Expression *> &properties) + : label_(label), properties_(properties) {} + + void ExpectOp(v2::plan::ScanByPrimaryKey &scan_all, const SymbolTable &) override { + EXPECT_EQ(scan_all.label_, label_); + + bool primary_property_match = true; + for (const auto &expected_prop : properties_) { + bool has_match = false; + for (const auto &prop : scan_all.primary_key_) { + if (typeid(prop).hash_code() == typeid(expected_prop).hash_code()) { + has_match = true; + } + } + if (!has_match) { + primary_property_match = false; + } + } + + EXPECT_TRUE(primary_property_match); + } + + private: + memgraph::storage::v3::LabelId label_; + std::vector<Expression *> properties_; +}; + class ExpectCartesian : public OpChecker<Cartesian> { public: ExpectCartesian(const std::list<std::unique_ptr<BaseOpChecker>> &left, diff --git a/tests/unit/query_plan_checker_v2.hpp b/tests/unit/query_plan_checker_v2.hpp new file mode 100644 index 000000000..014f1c2bb --- /dev/null +++ b/tests/unit/query_plan_checker_v2.hpp @@ -0,0 +1,452 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +#include "query/frontend/semantic/symbol_generator.hpp" +#include "query/frontend/semantic/symbol_table.hpp" +#include "query/v2/plan/operator.hpp" +#include "query/v2/plan/planner.hpp" +#include "query/v2/plan/preprocess.hpp" +#include "utils/exceptions.hpp" + +namespace memgraph::query::v2::plan { + +class BaseOpChecker { + public: + virtual ~BaseOpChecker() {} + + virtual void CheckOp(LogicalOperator &, const SymbolTable &) = 0; +}; + +class PlanChecker : public virtual HierarchicalLogicalOperatorVisitor { + public: + using HierarchicalLogicalOperatorVisitor::PostVisit; + using HierarchicalLogicalOperatorVisitor::PreVisit; + using HierarchicalLogicalOperatorVisitor::Visit; + + PlanChecker(const std::list<std::unique_ptr<BaseOpChecker>> &checkers, const SymbolTable &symbol_table) + : symbol_table_(symbol_table) { + for (const auto &checker : checkers) checkers_.emplace_back(checker.get()); + } + + PlanChecker(const std::list<BaseOpChecker *> &checkers, const SymbolTable &symbol_table) + : checkers_(checkers), symbol_table_(symbol_table) {} + +#define PRE_VISIT(TOp) \ + bool PreVisit(TOp &op) override { \ + CheckOp(op); \ + return true; \ + } + +#define VISIT(TOp) \ + bool Visit(TOp &op) override { \ + CheckOp(op); \ + return true; \ + } + + PRE_VISIT(CreateNode); + PRE_VISIT(CreateExpand); + PRE_VISIT(Delete); + PRE_VISIT(ScanAll); + PRE_VISIT(ScanAllByLabel); + PRE_VISIT(ScanAllByLabelPropertyValue); + PRE_VISIT(ScanAllByLabelPropertyRange); + PRE_VISIT(ScanAllByLabelProperty); + PRE_VISIT(ScanByPrimaryKey); + PRE_VISIT(Expand); + PRE_VISIT(ExpandVariable); + PRE_VISIT(Filter); + PRE_VISIT(ConstructNamedPath); + PRE_VISIT(Produce); + PRE_VISIT(SetProperty); + PRE_VISIT(SetProperties); + PRE_VISIT(SetLabels); + PRE_VISIT(RemoveProperty); + PRE_VISIT(RemoveLabels); + PRE_VISIT(EdgeUniquenessFilter); + PRE_VISIT(Accumulate); + PRE_VISIT(Aggregate); + PRE_VISIT(Skip); + PRE_VISIT(Limit); + PRE_VISIT(OrderBy); + bool PreVisit(Merge &op) override { + CheckOp(op); + op.input()->Accept(*this); + return false; + } + bool PreVisit(Optional &op) override { + CheckOp(op); + op.input()->Accept(*this); + return false; + } + PRE_VISIT(Unwind); + PRE_VISIT(Distinct); + + bool PreVisit(Foreach &op) override { + CheckOp(op); + return false; + } + + bool Visit(Once &) override { + // Ignore checking Once, it is implicitly at the end. + return true; + } + + bool PreVisit(Cartesian &op) override { + CheckOp(op); + return false; + } + + PRE_VISIT(CallProcedure); + +#undef PRE_VISIT +#undef VISIT + + void CheckOp(LogicalOperator &op) { + ASSERT_FALSE(checkers_.empty()); + checkers_.back()->CheckOp(op, symbol_table_); + checkers_.pop_back(); + } + + std::list<BaseOpChecker *> checkers_; + const SymbolTable &symbol_table_; +}; + +template <class TOp> +class OpChecker : public BaseOpChecker { + public: + void CheckOp(LogicalOperator &op, const SymbolTable &symbol_table) override { + auto *expected_op = dynamic_cast<TOp *>(&op); + ASSERT_TRUE(expected_op) << "op is '" << op.GetTypeInfo().name << "' expected '" << TOp::kType.name << "'!"; + ExpectOp(*expected_op, symbol_table); + } + + virtual void ExpectOp(TOp &, const SymbolTable &) {} +}; + +using ExpectCreateNode = OpChecker<CreateNode>; +using ExpectCreateExpand = OpChecker<CreateExpand>; +using ExpectDelete = OpChecker<Delete>; +using ExpectScanAll = OpChecker<ScanAll>; +using ExpectScanAllByLabel = OpChecker<ScanAllByLabel>; +using ExpectExpand = OpChecker<Expand>; +using ExpectFilter = OpChecker<Filter>; +using ExpectConstructNamedPath = OpChecker<ConstructNamedPath>; +using ExpectProduce = OpChecker<Produce>; +using ExpectSetProperty = OpChecker<SetProperty>; +using ExpectSetProperties = OpChecker<SetProperties>; +using ExpectSetLabels = OpChecker<SetLabels>; +using ExpectRemoveProperty = OpChecker<RemoveProperty>; +using ExpectRemoveLabels = OpChecker<RemoveLabels>; +using ExpectEdgeUniquenessFilter = OpChecker<EdgeUniquenessFilter>; +using ExpectSkip = OpChecker<Skip>; +using ExpectLimit = OpChecker<Limit>; +using ExpectOrderBy = OpChecker<OrderBy>; +using ExpectUnwind = OpChecker<Unwind>; +using ExpectDistinct = OpChecker<Distinct>; + +class ExpectScanAllByLabelPropertyValue : public OpChecker<ScanAllByLabelPropertyValue> { + public: + ExpectScanAllByLabelPropertyValue(memgraph::storage::v3::LabelId label, + const std::pair<std::string, memgraph::storage::v3::PropertyId> &prop_pair, + memgraph::query::v2::Expression *expression) + : label_(label), property_(prop_pair.second), expression_(expression) {} + + void ExpectOp(ScanAllByLabelPropertyValue &scan_all, const SymbolTable &) override { + EXPECT_EQ(scan_all.label_, label_); + EXPECT_EQ(scan_all.property_, property_); + // TODO: Proper expression equality + EXPECT_EQ(typeid(scan_all.expression_).hash_code(), typeid(expression_).hash_code()); + } + + private: + memgraph::storage::v3::LabelId label_; + memgraph::storage::v3::PropertyId property_; + memgraph::query::v2::Expression *expression_; +}; + +class ExpectScanByPrimaryKey : public OpChecker<v2::plan::ScanByPrimaryKey> { + public: + ExpectScanByPrimaryKey(memgraph::storage::v3::LabelId label, const std::vector<Expression *> &properties) + : label_(label), properties_(properties) {} + + void ExpectOp(v2::plan::ScanByPrimaryKey &scan_all, const SymbolTable &) override { + EXPECT_EQ(scan_all.label_, label_); + + bool primary_property_match = true; + for (const auto &expected_prop : properties_) { + bool has_match = false; + for (const auto &prop : scan_all.primary_key_) { + if (typeid(prop).hash_code() == typeid(expected_prop).hash_code()) { + has_match = true; + } + } + if (!has_match) { + primary_property_match = false; + } + } + + EXPECT_TRUE(primary_property_match); + } + + private: + memgraph::storage::v3::LabelId label_; + std::vector<Expression *> properties_; +}; + +class ExpectCartesian : public OpChecker<Cartesian> { + public: + ExpectCartesian(const std::list<std::unique_ptr<BaseOpChecker>> &left, + const std::list<std::unique_ptr<BaseOpChecker>> &right) + : left_(left), right_(right) {} + + void ExpectOp(Cartesian &op, const SymbolTable &symbol_table) override { + ASSERT_TRUE(op.left_op_); + PlanChecker left_checker(left_, symbol_table); + op.left_op_->Accept(left_checker); + ASSERT_TRUE(op.right_op_); + PlanChecker right_checker(right_, symbol_table); + op.right_op_->Accept(right_checker); + } + + private: + const std::list<std::unique_ptr<BaseOpChecker>> &left_; + const std::list<std::unique_ptr<BaseOpChecker>> &right_; +}; + +class ExpectCallProcedure : public OpChecker<CallProcedure> { + public: + ExpectCallProcedure(const std::string &name, const std::vector<memgraph::query::Expression *> &args, + const std::vector<std::string> &fields, const std::vector<Symbol> &result_syms) + : name_(name), args_(args), fields_(fields), result_syms_(result_syms) {} + + void ExpectOp(CallProcedure &op, const SymbolTable &symbol_table) override { + EXPECT_EQ(op.procedure_name_, name_); + EXPECT_EQ(op.arguments_.size(), args_.size()); + for (size_t i = 0; i < args_.size(); ++i) { + const auto *op_arg = op.arguments_[i]; + const auto *expected_arg = args_[i]; + // TODO: Proper expression equality + EXPECT_EQ(op_arg->GetTypeInfo(), expected_arg->GetTypeInfo()); + } + EXPECT_EQ(op.result_fields_, fields_); + EXPECT_EQ(op.result_symbols_, result_syms_); + } + + private: + std::string name_; + std::vector<memgraph::query::Expression *> args_; + std::vector<std::string> fields_; + std::vector<Symbol> result_syms_; +}; + +template <class T> +std::list<std::unique_ptr<BaseOpChecker>> MakeCheckers(T arg) { + std::list<std::unique_ptr<BaseOpChecker>> l; + l.emplace_back(std::make_unique<T>(arg)); + return l; +} + +template <class T, class... Rest> +std::list<std::unique_ptr<BaseOpChecker>> MakeCheckers(T arg, Rest &&...rest) { + auto l = MakeCheckers(std::forward<Rest>(rest)...); + l.emplace_front(std::make_unique<T>(arg)); + return std::move(l); +} + +template <class TPlanner, class TDbAccessor> +TPlanner MakePlanner(TDbAccessor *dba, AstStorage &storage, SymbolTable &symbol_table, CypherQuery *query) { + auto planning_context = MakePlanningContext(&storage, &symbol_table, query, dba); + auto query_parts = CollectQueryParts(symbol_table, storage, query); + auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; + return TPlanner(single_query_parts, planning_context); +} + +class FakeDistributedDbAccessor { + public: + int64_t VerticesCount(memgraph::storage::v3::LabelId label) const { + auto found = label_index_.find(label); + if (found != label_index_.end()) return found->second; + return 0; + } + + int64_t VerticesCount(memgraph::storage::v3::LabelId label, memgraph::storage::v3::PropertyId property) const { + for (auto &index : label_property_index_) { + if (std::get<0>(index) == label && std::get<1>(index) == property) { + return std::get<2>(index); + } + } + return 0; + } + + bool LabelIndexExists(memgraph::storage::v3::LabelId label) const { + throw utils::NotYetImplemented("Label indicies are yet to be implemented."); + } + + bool LabelPropertyIndexExists(memgraph::storage::v3::LabelId label, + memgraph::storage::v3::PropertyId property) const { + for (auto &index : label_property_index_) { + if (std::get<0>(index) == label && std::get<1>(index) == property) { + return true; + } + } + return false; + } + + bool PrimaryLabelExists(storage::v3::LabelId label) { return label_index_.find(label) != label_index_.end(); } + + void SetIndexCount(memgraph::storage::v3::LabelId label, int64_t count) { label_index_[label] = count; } + + void SetIndexCount(memgraph::storage::v3::LabelId label, memgraph::storage::v3::PropertyId property, int64_t count) { + for (auto &index : label_property_index_) { + if (std::get<0>(index) == label && std::get<1>(index) == property) { + std::get<2>(index) = count; + return; + } + } + label_property_index_.emplace_back(label, property, count); + } + + memgraph::storage::v3::LabelId NameToLabel(const std::string &name) { + auto found = primary_labels_.find(name); + if (found != primary_labels_.end()) return found->second; + return primary_labels_.emplace(name, memgraph::storage::v3::LabelId::FromUint(primary_labels_.size())) + .first->second; + } + + memgraph::storage::v3::LabelId Label(const std::string &name) { return NameToLabel(name); } + + memgraph::storage::v3::EdgeTypeId NameToEdgeType(const std::string &name) { + auto found = edge_types_.find(name); + if (found != edge_types_.end()) return found->second; + return edge_types_.emplace(name, memgraph::storage::v3::EdgeTypeId::FromUint(edge_types_.size())).first->second; + } + + memgraph::storage::v3::PropertyId NameToPrimaryProperty(const std::string &name) { + auto found = primary_properties_.find(name); + if (found != primary_properties_.end()) return found->second; + return primary_properties_.emplace(name, memgraph::storage::v3::PropertyId::FromUint(primary_properties_.size())) + .first->second; + } + + memgraph::storage::v3::PropertyId NameToSecondaryProperty(const std::string &name) { + auto found = secondary_properties_.find(name); + if (found != secondary_properties_.end()) return found->second; + return secondary_properties_ + .emplace(name, memgraph::storage::v3::PropertyId::FromUint(secondary_properties_.size())) + .first->second; + } + + memgraph::storage::v3::PropertyId PrimaryProperty(const std::string &name) { return NameToPrimaryProperty(name); } + memgraph::storage::v3::PropertyId SecondaryProperty(const std::string &name) { return NameToSecondaryProperty(name); } + + std::string PrimaryPropertyToName(memgraph::storage::v3::PropertyId property) const { + for (const auto &kv : primary_properties_) { + if (kv.second == property) return kv.first; + } + LOG_FATAL("Unable to find primary property name"); + } + + std::string SecondaryPropertyToName(memgraph::storage::v3::PropertyId property) const { + for (const auto &kv : secondary_properties_) { + if (kv.second == property) return kv.first; + } + LOG_FATAL("Unable to find secondary property name"); + } + + std::string PrimaryPropertyName(memgraph::storage::v3::PropertyId property) const { + return PrimaryPropertyToName(property); + } + std::string SecondaryPropertyName(memgraph::storage::v3::PropertyId property) const { + return SecondaryPropertyToName(property); + } + + memgraph::storage::v3::PropertyId NameToProperty(const std::string &name) { + auto find_in_prim_properties = primary_properties_.find(name); + if (find_in_prim_properties != primary_properties_.end()) { + return find_in_prim_properties->second; + } + auto find_in_secondary_properties = secondary_properties_.find(name); + if (find_in_secondary_properties != secondary_properties_.end()) { + return find_in_secondary_properties->second; + } + + LOG_FATAL("The property does not exist as a primary or a secondary property."); + return memgraph::storage::v3::PropertyId::FromUint(0); + } + + std::vector<memgraph::storage::v3::SchemaProperty> GetSchemaForLabel(storage::v3::LabelId label) { + auto schema_properties = schemas_.at(label); + std::vector<memgraph::storage::v3::SchemaProperty> ret; + std::transform(schema_properties.begin(), schema_properties.end(), std::back_inserter(ret), [](const auto &prop) { + memgraph::storage::v3::SchemaProperty schema_prop = { + .property_id = prop, + // This should not be hardcoded, but for testing purposes it will suffice. + .type = memgraph::common::SchemaType::INT}; + + return schema_prop; + }); + return ret; + } + + std::vector<std::pair<query::v2::Expression *, query::v2::plan::FilterInfo>> ExtractPrimaryKey( + storage::v3::LabelId label, std::vector<query::v2::plan::FilterInfo> property_filters) { + MG_ASSERT(schemas_.contains(label), + "You did not specify the Schema for this label! Use FakeDistributedDbAccessor::CreateSchema(...)."); + + std::vector<std::pair<query::v2::Expression *, query::v2::plan::FilterInfo>> pk; + const auto schema = GetSchemaPropertiesForLabel(label); + + std::vector<storage::v3::PropertyId> schema_properties; + schema_properties.reserve(schema.size()); + + std::transform(schema.begin(), schema.end(), std::back_inserter(schema_properties), + [](const auto &schema_elem) { return schema_elem; }); + + for (const auto &property_filter : property_filters) { + const auto &property_id = NameToProperty(property_filter.property_filter->property_.name); + if (std::find(schema_properties.begin(), schema_properties.end(), property_id) != schema_properties.end()) { + pk.emplace_back(std::make_pair(property_filter.expression, property_filter)); + } + } + + return pk.size() == schema_properties.size() + ? pk + : std::vector<std::pair<query::v2::Expression *, query::v2::plan::FilterInfo>>{}; + } + + std::vector<memgraph::storage::v3::PropertyId> GetSchemaPropertiesForLabel(storage::v3::LabelId label) { + return schemas_.at(label); + } + + void CreateSchema(const memgraph::storage::v3::LabelId primary_label, + const std::vector<memgraph::storage::v3::PropertyId> &schemas_types) { + MG_ASSERT(!schemas_.contains(primary_label), "You already created the schema for this label!"); + schemas_.emplace(primary_label, schemas_types); + } + + private: + std::unordered_map<std::string, memgraph::storage::v3::LabelId> primary_labels_; + std::unordered_map<std::string, memgraph::storage::v3::LabelId> secondary_labels_; + std::unordered_map<std::string, memgraph::storage::v3::EdgeTypeId> edge_types_; + std::unordered_map<std::string, memgraph::storage::v3::PropertyId> primary_properties_; + std::unordered_map<std::string, memgraph::storage::v3::PropertyId> secondary_properties_; + + std::unordered_map<memgraph::storage::v3::LabelId, int64_t> label_index_; + std::vector<std::tuple<memgraph::storage::v3::LabelId, memgraph::storage::v3::PropertyId, int64_t>> + label_property_index_; + + std::unordered_map<memgraph::storage::v3::LabelId, std::vector<memgraph::storage::v3::PropertyId>> schemas_; +}; + +} // namespace memgraph::query::v2::plan diff --git a/tests/unit/query_v2_common.hpp b/tests/unit/query_v2_common.hpp new file mode 100644 index 000000000..85905cb22 --- /dev/null +++ b/tests/unit/query_v2_common.hpp @@ -0,0 +1,603 @@ +// Copyright 2022 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +/// @file +/// This file provides macros for easier construction of openCypher query AST. +/// The usage of macros is very similar to how one would write openCypher. For +/// example: +/// +/// AstStorage storage; // Macros rely on storage being in scope. +/// // PROPERTY_LOOKUP and PROPERTY_PAIR macros +/// // rely on a DbAccessor *reference* named dba. +/// database::GraphDb db; +/// auto dba_ptr = db.Access(); +/// auto &dba = *dba_ptr; +/// +/// QUERY(MATCH(PATTERN(NODE("n"), EDGE("e"), NODE("m"))), +/// WHERE(LESS(PROPERTY_LOOKUP("e", edge_prop), LITERAL(3))), +/// RETURN(SUM(PROPERTY_LOOKUP("m", prop)), AS("sum"), +/// ORDER_BY(IDENT("sum")), +/// SKIP(ADD(LITERAL(1), LITERAL(2))))); +/// +/// Each of the macros is accompanied by a function. The functions use overload +/// resolution and template magic to provide a type safe way of constructing +/// queries. Although the functions can be used by themselves, it is more +/// convenient to use the macros. + +#pragma once + +#include <map> +#include <sstream> +#include <string> +#include <unordered_map> +#include <utility> +#include <vector> + +#include "query/frontend/ast/pretty_print.hpp" // not sure if that is ok... +#include "query/v2/frontend/ast/ast.hpp" +#include "storage/v3/id_types.hpp" +#include "utils/string.hpp" + +#include "query/v2/frontend/ast/ast.hpp" + +namespace memgraph::query::v2 { + +namespace test_common { + +auto ToIntList(const TypedValue &t) { + std::vector<int64_t> list; + for (auto x : t.ValueList()) { + list.push_back(x.ValueInt()); + } + return list; +}; + +auto ToIntMap(const TypedValue &t) { + std::map<std::string, int64_t> map; + for (const auto &kv : t.ValueMap()) map.emplace(kv.first, kv.second.ValueInt()); + return map; +}; + +std::string ToString(Expression *expr) { + std::ostringstream ss; + // PrintExpression(expr, &ss); + return ss.str(); +} + +std::string ToString(NamedExpression *expr) { + std::ostringstream ss; + // PrintExpression(expr, &ss); + return ss.str(); +} + +// Custom types for ORDER BY, SKIP, LIMIT, ON MATCH and ON CREATE expressions, +// so that they can be used to resolve function calls. +struct OrderBy { + std::vector<SortItem> expressions; +}; + +struct Skip { + Expression *expression = nullptr; +}; + +struct Limit { + Expression *expression = nullptr; +}; + +struct OnMatch { + std::vector<Clause *> set; +}; +struct OnCreate { + std::vector<Clause *> set; +}; + +// Helper functions for filling the OrderBy with expressions. +auto FillOrderBy(OrderBy &order_by, Expression *expression, Ordering ordering = Ordering::ASC) { + order_by.expressions.push_back({ordering, expression}); +} +template <class... T> +auto FillOrderBy(OrderBy &order_by, Expression *expression, Ordering ordering, T... rest) { + FillOrderBy(order_by, expression, ordering); + FillOrderBy(order_by, rest...); +} +template <class... T> +auto FillOrderBy(OrderBy &order_by, Expression *expression, T... rest) { + FillOrderBy(order_by, expression); + FillOrderBy(order_by, rest...); +} + +/// Create OrderBy expressions. +/// +/// The supported combination of arguments is: (Expression, [Ordering])+ +/// Since the Ordering is optional, by default it is ascending. +template <class... T> +auto GetOrderBy(T... exprs) { + OrderBy order_by; + FillOrderBy(order_by, exprs...); + return order_by; +} + +/// Create PropertyLookup with given name and property. +/// +/// Name is used to create the Identifier which is used for property lookup. +template <class TDbAccessor> +auto GetPropertyLookup(AstStorage &storage, TDbAccessor &dba, const std::string &name, + memgraph::storage::v3::PropertyId property) { + return storage.Create<PropertyLookup>(storage.Create<Identifier>(name), + storage.GetPropertyIx(dba.PropertyToName(property))); +} + +template <class TDbAccessor> +auto GetPropertyLookup(AstStorage &storage, TDbAccessor &dba, Expression *expr, + memgraph::storage::v3::PropertyId property) { + return storage.Create<PropertyLookup>(expr, storage.GetPropertyIx(dba.PropertyToName(property))); +} + +template <class TDbAccessor> +auto GetPropertyLookup(AstStorage &storage, TDbAccessor &dba, Expression *expr, const std::string &property) { + return storage.Create<PropertyLookup>(expr, storage.GetPropertyIx(property)); +} + +template <class TDbAccessor> +auto GetPropertyLookup(AstStorage &storage, TDbAccessor &, const std::string &name, + const std::pair<std::string, memgraph::storage::v3::PropertyId> &prop_pair) { + return storage.Create<PropertyLookup>(storage.Create<Identifier>(name), storage.GetPropertyIx(prop_pair.first)); +} + +template <class TDbAccessor> +auto GetPropertyLookup(AstStorage &storage, TDbAccessor &, Expression *expr, + const std::pair<std::string, memgraph::storage::v3::PropertyId> &prop_pair) { + return storage.Create<PropertyLookup>(expr, storage.GetPropertyIx(prop_pair.first)); +} + +/// Create an EdgeAtom with given name, direction and edge_type. +/// +/// Name is used to create the Identifier which is assigned to the edge. +auto GetEdge(AstStorage &storage, const std::string &name, EdgeAtom::Direction dir = EdgeAtom::Direction::BOTH, + const std::vector<std::string> &edge_types = {}) { + std::vector<EdgeTypeIx> types; + types.reserve(edge_types.size()); + for (const auto &type : edge_types) { + types.push_back(storage.GetEdgeTypeIx(type)); + } + return storage.Create<EdgeAtom>(storage.Create<Identifier>(name), EdgeAtom::Type::SINGLE, dir, types); +} + +/// Create a variable length expansion EdgeAtom with given name, direction and +/// edge_type. +/// +/// Name is used to create the Identifier which is assigned to the edge. +auto GetEdgeVariable(AstStorage &storage, const std::string &name, EdgeAtom::Type type = EdgeAtom::Type::DEPTH_FIRST, + EdgeAtom::Direction dir = EdgeAtom::Direction::BOTH, + const std::vector<std::string> &edge_types = {}, Identifier *flambda_inner_edge = nullptr, + Identifier *flambda_inner_node = nullptr, Identifier *wlambda_inner_edge = nullptr, + Identifier *wlambda_inner_node = nullptr, Expression *wlambda_expression = nullptr, + Identifier *total_weight = nullptr) { + std::vector<EdgeTypeIx> types; + types.reserve(edge_types.size()); + for (const auto &type : edge_types) { + types.push_back(storage.GetEdgeTypeIx(type)); + } + auto r_val = storage.Create<EdgeAtom>(storage.Create<Identifier>(name), type, dir, types); + + r_val->filter_lambda_.inner_edge = + flambda_inner_edge ? flambda_inner_edge : storage.Create<Identifier>(memgraph::utils::RandomString(20)); + r_val->filter_lambda_.inner_node = + flambda_inner_node ? flambda_inner_node : storage.Create<Identifier>(memgraph::utils::RandomString(20)); + + if (type == EdgeAtom::Type::WEIGHTED_SHORTEST_PATH) { + r_val->weight_lambda_.inner_edge = + wlambda_inner_edge ? wlambda_inner_edge : storage.Create<Identifier>(memgraph::utils::RandomString(20)); + r_val->weight_lambda_.inner_node = + wlambda_inner_node ? wlambda_inner_node : storage.Create<Identifier>(memgraph::utils::RandomString(20)); + r_val->weight_lambda_.expression = + wlambda_expression ? wlambda_expression : storage.Create<memgraph::query::v2::PrimitiveLiteral>(1); + + r_val->total_weight_ = total_weight; + } + + return r_val; +} + +/// Create a NodeAtom with given name and label. +/// +/// Name is used to create the Identifier which is assigned to the node. +auto GetNode(AstStorage &storage, const std::string &name, std::optional<std::string> label = std::nullopt) { + auto node = storage.Create<NodeAtom>(storage.Create<Identifier>(name)); + if (label) node->labels_.emplace_back(storage.GetLabelIx(*label)); + return node; +} + +/// Create a Pattern with given atoms. +auto GetPattern(AstStorage &storage, std::vector<PatternAtom *> atoms) { + auto pattern = storage.Create<Pattern>(); + pattern->identifier_ = storage.Create<Identifier>(memgraph::utils::RandomString(20), false); + pattern->atoms_.insert(pattern->atoms_.begin(), atoms.begin(), atoms.end()); + return pattern; +} + +/// Create a Pattern with given name and atoms. +auto GetPattern(AstStorage &storage, const std::string &name, std::vector<PatternAtom *> atoms) { + auto pattern = storage.Create<Pattern>(); + pattern->identifier_ = storage.Create<Identifier>(name, true); + pattern->atoms_.insert(pattern->atoms_.begin(), atoms.begin(), atoms.end()); + return pattern; +} + +/// This function fills an AST node which with given patterns. +/// +/// The function is most commonly used to create Match and Create clauses. +template <class TWithPatterns> +auto GetWithPatterns(TWithPatterns *with_patterns, std::vector<Pattern *> patterns) { + with_patterns->patterns_.insert(with_patterns->patterns_.begin(), patterns.begin(), patterns.end()); + return with_patterns; +} + +/// Create a query with given clauses. + +auto GetSingleQuery(SingleQuery *single_query, Clause *clause) { + single_query->clauses_.emplace_back(clause); + return single_query; +} +auto GetSingleQuery(SingleQuery *single_query, Match *match, Where *where) { + match->where_ = where; + single_query->clauses_.emplace_back(match); + return single_query; +} +auto GetSingleQuery(SingleQuery *single_query, With *with, Where *where) { + with->where_ = where; + single_query->clauses_.emplace_back(with); + return single_query; +} +template <class... T> +auto GetSingleQuery(SingleQuery *single_query, Match *match, Where *where, T *...clauses) { + match->where_ = where; + single_query->clauses_.emplace_back(match); + return GetSingleQuery(single_query, clauses...); +} +template <class... T> +auto GetSingleQuery(SingleQuery *single_query, With *with, Where *where, T *...clauses) { + with->where_ = where; + single_query->clauses_.emplace_back(with); + return GetSingleQuery(single_query, clauses...); +} + +template <class... T> +auto GetSingleQuery(SingleQuery *single_query, Clause *clause, T *...clauses) { + single_query->clauses_.emplace_back(clause); + return GetSingleQuery(single_query, clauses...); +} + +auto GetCypherUnion(CypherUnion *cypher_union, SingleQuery *single_query) { + cypher_union->single_query_ = single_query; + return cypher_union; +} + +auto GetQuery(AstStorage &storage, SingleQuery *single_query) { + auto *query = storage.Create<CypherQuery>(); + query->single_query_ = single_query; + return query; +} + +template <class... T> +auto GetQuery(AstStorage &storage, SingleQuery *single_query, T *...cypher_unions) { + auto *query = storage.Create<CypherQuery>(); + query->single_query_ = single_query; + query->cypher_unions_ = std::vector<CypherUnion *>{cypher_unions...}; + return query; +} + +// Helper functions for constructing RETURN and WITH clauses. +void FillReturnBody(AstStorage &, ReturnBody &body, NamedExpression *named_expr) { + body.named_expressions.emplace_back(named_expr); +} +void FillReturnBody(AstStorage &storage, ReturnBody &body, const std::string &name) { + if (name == "*") { + body.all_identifiers = true; + } else { + auto *ident = storage.Create<memgraph::query::v2::Identifier>(name); + auto *named_expr = storage.Create<memgraph::query::v2::NamedExpression>(name, ident); + body.named_expressions.emplace_back(named_expr); + } +} +void FillReturnBody(AstStorage &, ReturnBody &body, Limit limit) { body.limit = limit.expression; } +void FillReturnBody(AstStorage &, ReturnBody &body, Skip skip, Limit limit = Limit{}) { + body.skip = skip.expression; + body.limit = limit.expression; +} +void FillReturnBody(AstStorage &, ReturnBody &body, OrderBy order_by, Limit limit = Limit{}) { + body.order_by = order_by.expressions; + body.limit = limit.expression; +} +void FillReturnBody(AstStorage &, ReturnBody &body, OrderBy order_by, Skip skip, Limit limit = Limit{}) { + body.order_by = order_by.expressions; + body.skip = skip.expression; + body.limit = limit.expression; +} +void FillReturnBody(AstStorage &, ReturnBody &body, Expression *expr, NamedExpression *named_expr) { + // This overload supports `RETURN(expr, AS(name))` construct, since + // NamedExpression does not inherit Expression. + named_expr->expression_ = expr; + body.named_expressions.emplace_back(named_expr); +} +void FillReturnBody(AstStorage &storage, ReturnBody &body, const std::string &name, NamedExpression *named_expr) { + named_expr->expression_ = storage.Create<memgraph::query::v2::Identifier>(name); + body.named_expressions.emplace_back(named_expr); +} +template <class... T> +void FillReturnBody(AstStorage &storage, ReturnBody &body, Expression *expr, NamedExpression *named_expr, T... rest) { + named_expr->expression_ = expr; + body.named_expressions.emplace_back(named_expr); + FillReturnBody(storage, body, rest...); +} +template <class... T> +void FillReturnBody(AstStorage &storage, ReturnBody &body, NamedExpression *named_expr, T... rest) { + body.named_expressions.emplace_back(named_expr); + FillReturnBody(storage, body, rest...); +} +template <class... T> +void FillReturnBody(AstStorage &storage, ReturnBody &body, const std::string &name, NamedExpression *named_expr, + T... rest) { + named_expr->expression_ = storage.Create<memgraph::query::v2::Identifier>(name); + body.named_expressions.emplace_back(named_expr); + FillReturnBody(storage, body, rest...); +} +template <class... T> +void FillReturnBody(AstStorage &storage, ReturnBody &body, const std::string &name, T... rest) { + auto *ident = storage.Create<memgraph::query::v2::Identifier>(name); + auto *named_expr = storage.Create<memgraph::query::v2::NamedExpression>(name, ident); + body.named_expressions.emplace_back(named_expr); + FillReturnBody(storage, body, rest...); +} + +/// Create the return clause with given expressions. +/// +/// The supported expression combination of arguments is: +/// +/// (String | NamedExpression | (Expression NamedExpression))+ +/// [OrderBy] [Skip] [Limit] +/// +/// When the pair (Expression NamedExpression) is given, the Expression will be +/// moved inside the NamedExpression. This is done, so that the constructs like +/// RETURN(expr, AS("name"), ...) are supported. Taking a String is a shorthand +/// for RETURN(IDENT(string), AS(string), ....). +/// +/// @sa GetWith +template <class... T> +auto GetReturn(AstStorage &storage, bool distinct, T... exprs) { + auto ret = storage.Create<Return>(); + ret->body_.distinct = distinct; + FillReturnBody(storage, ret->body_, exprs...); + return ret; +} + +/// Create the with clause with given expressions. +/// +/// The supported expression combination is the same as for @c GetReturn. +/// +/// @sa GetReturn +template <class... T> +auto GetWith(AstStorage &storage, bool distinct, T... exprs) { + auto with = storage.Create<With>(); + with->body_.distinct = distinct; + FillReturnBody(storage, with->body_, exprs...); + return with; +} + +/// Create the UNWIND clause with given named expression. +auto GetUnwind(AstStorage &storage, NamedExpression *named_expr) { + return storage.Create<memgraph::query::v2::Unwind>(named_expr); +} +auto GetUnwind(AstStorage &storage, Expression *expr, NamedExpression *as) { + as->expression_ = expr; + return GetUnwind(storage, as); +} + +/// Create the delete clause with given named expressions. +auto GetDelete(AstStorage &storage, std::vector<Expression *> exprs, bool detach = false) { + auto del = storage.Create<Delete>(); + del->expressions_.insert(del->expressions_.begin(), exprs.begin(), exprs.end()); + del->detach_ = detach; + return del; +} + +/// Create a set property clause for given property lookup and the right hand +/// side expression. +auto GetSet(AstStorage &storage, PropertyLookup *prop_lookup, Expression *expr) { + return storage.Create<SetProperty>(prop_lookup, expr); +} + +/// Create a set properties clause for given identifier name and the right hand +/// side expression. +auto GetSet(AstStorage &storage, const std::string &name, Expression *expr, bool update = false) { + return storage.Create<SetProperties>(storage.Create<Identifier>(name), expr, update); +} + +/// Create a set labels clause for given identifier name and labels. +auto GetSet(AstStorage &storage, const std::string &name, std::vector<std::string> label_names) { + std::vector<LabelIx> labels; + labels.reserve(label_names.size()); + for (const auto &label : label_names) { + labels.push_back(storage.GetLabelIx(label)); + } + return storage.Create<SetLabels>(storage.Create<Identifier>(name), labels); +} + +/// Create a remove property clause for given property lookup +auto GetRemove(AstStorage &storage, PropertyLookup *prop_lookup) { return storage.Create<RemoveProperty>(prop_lookup); } + +/// Create a remove labels clause for given identifier name and labels. +auto GetRemove(AstStorage &storage, const std::string &name, std::vector<std::string> label_names) { + std::vector<LabelIx> labels; + labels.reserve(label_names.size()); + for (const auto &label : label_names) { + labels.push_back(storage.GetLabelIx(label)); + } + return storage.Create<RemoveLabels>(storage.Create<Identifier>(name), labels); +} + +/// Create a Merge clause for given Pattern with optional OnMatch and OnCreate +/// parts. +auto GetMerge(AstStorage &storage, Pattern *pattern, OnCreate on_create = OnCreate{}) { + auto *merge = storage.Create<memgraph::query::v2::Merge>(); + merge->pattern_ = pattern; + merge->on_create_ = on_create.set; + return merge; +} +auto GetMerge(AstStorage &storage, Pattern *pattern, OnMatch on_match, OnCreate on_create = OnCreate{}) { + auto *merge = storage.Create<memgraph::query::v2::Merge>(); + merge->pattern_ = pattern; + merge->on_match_ = on_match.set; + merge->on_create_ = on_create.set; + return merge; +} + +auto GetCallProcedure(AstStorage &storage, std::string procedure_name, + std::vector<memgraph::query::v2::Expression *> arguments = {}) { + auto *call_procedure = storage.Create<memgraph::query::v2::CallProcedure>(); + call_procedure->procedure_name_ = std::move(procedure_name); + call_procedure->arguments_ = std::move(arguments); + return call_procedure; +} + +/// Create the FOREACH clause with given named expression. +auto GetForeach(AstStorage &storage, NamedExpression *named_expr, const std::vector<query::v2::Clause *> &clauses) { + return storage.Create<query::v2::Foreach>(named_expr, clauses); +} + +} // namespace test_common + +} // namespace memgraph::query::v2 + +/// All the following macros implicitly pass `storage` variable to functions. +/// You need to have `AstStorage storage;` somewhere in scope to use them. +/// Refer to function documentation to see what the macro does. +/// +/// Example usage: +/// +/// // Create MATCH (n) -[r]- (m) RETURN m AS new_name +/// AstStorage storage; +/// auto query = QUERY(MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m"))), +/// RETURN(NEXPR("new_name"), IDENT("m"))); +#define NODE(...) memgraph::expr::test_common::GetNode(storage, __VA_ARGS__) +#define EDGE(...) memgraph::expr::test_common::GetEdge(storage, __VA_ARGS__) +#define EDGE_VARIABLE(...) memgraph::expr::test_common::GetEdgeVariable(storage, __VA_ARGS__) +#define PATTERN(...) memgraph::expr::test_common::GetPattern(storage, {__VA_ARGS__}) +#define PATTERN(...) memgraph::expr::test_common::GetPattern(storage, {__VA_ARGS__}) +#define NAMED_PATTERN(name, ...) memgraph::expr::test_common::GetPattern(storage, name, {__VA_ARGS__}) +#define OPTIONAL_MATCH(...) \ + memgraph::expr::test_common::GetWithPatterns(storage.Create<memgraph::query::v2::Match>(true), {__VA_ARGS__}) +#define MATCH(...) \ + memgraph::expr::test_common::GetWithPatterns(storage.Create<memgraph::query::v2::Match>(), {__VA_ARGS__}) +#define WHERE(expr) storage.Create<memgraph::query::v2::Where>((expr)) +#define CREATE(...) \ + memgraph::expr::test_common::GetWithPatterns(storage.Create<memgraph::query::v2::Create>(), {__VA_ARGS__}) +#define IDENT(...) storage.Create<memgraph::query::v2::Identifier>(__VA_ARGS__) +#define LITERAL(val) storage.Create<memgraph::query::v2::PrimitiveLiteral>((val)) +#define LIST(...) \ + storage.Create<memgraph::query::v2::ListLiteral>(std::vector<memgraph::query::v2::Expression *>{__VA_ARGS__}) +#define MAP(...) \ + storage.Create<memgraph::query::v2::MapLiteral>( \ + std::unordered_map<memgraph::query::v2::PropertyIx, memgraph::query::v2::Expression *>{__VA_ARGS__}) +#define PROPERTY_PAIR(property_name) \ + std::make_pair(property_name, dba.NameToProperty(property_name)) // This one might not be needed at all +#define PRIMARY_PROPERTY_PAIR(property_name) std::make_pair(property_name, dba.NameToPrimaryProperty(property_name)) +#define SECONDARY_PROPERTY_PAIR(property_name) std::make_pair(property_name, dba.NameToSecondaryProperty(property_name)) +#define PROPERTY_LOOKUP(...) memgraph::expr::test_common::GetPropertyLookup(storage, dba, __VA_ARGS__) +#define PARAMETER_LOOKUP(token_position) storage.Create<memgraph::query::v2::ParameterLookup>((token_position)) +#define NEXPR(name, expr) storage.Create<memgraph::query::v2::NamedExpression>((name), (expr)) +// AS is alternative to NEXPR which does not initialize NamedExpression with +// Expression. It should be used with RETURN or WITH. For example: +// RETURN(IDENT("n"), AS("n")) vs. RETURN(NEXPR("n", IDENT("n"))). +#define AS(name) storage.Create<memgraph::query::v2::NamedExpression>((name)) +#define RETURN(...) memgraph::expr::test_common::GetReturn(storage, false, __VA_ARGS__) +#define WITH(...) memgraph::expr::test_common::GetWith(storage, false, __VA_ARGS__) +#define RETURN_DISTINCT(...) memgraph::expr::test_common::GetReturn(storage, true, __VA_ARGS__) +#define WITH_DISTINCT(...) memgraph::expr::test_common::GetWith(storage, true, __VA_ARGS__) +#define UNWIND(...) memgraph::expr::test_common::GetUnwind(storage, __VA_ARGS__) +#define ORDER_BY(...) memgraph::expr::test_common::GetOrderBy(__VA_ARGS__) +#define SKIP(expr) \ + memgraph::expr::test_common::Skip { (expr) } +#define LIMIT(expr) \ + memgraph::expr::test_common::Limit { (expr) } +#define DELETE(...) memgraph::expr::test_common::GetDelete(storage, {__VA_ARGS__}) +#define DETACH_DELETE(...) memgraph::expr::test_common::GetDelete(storage, {__VA_ARGS__}, true) +#define SET(...) memgraph::expr::test_common::GetSet(storage, __VA_ARGS__) +#define REMOVE(...) memgraph::expr::test_common::GetRemove(storage, __VA_ARGS__) +#define MERGE(...) memgraph::expr::test_common::GetMerge(storage, __VA_ARGS__) +#define ON_MATCH(...) \ + memgraph::expr::test_common::OnMatch { \ + std::vector<memgraph::query::v2::Clause *> { __VA_ARGS__ } \ + } +#define ON_CREATE(...) \ + memgraph::expr::test_common::OnCreate { \ + std::vector<memgraph::query::v2::Clause *> { __VA_ARGS__ } \ + } +#define CREATE_INDEX_ON(label, property) \ + storage.Create<memgraph::query::v2::IndexQuery>(memgraph::query::v2::IndexQuery::Action::CREATE, (label), \ + std::vector<memgraph::query::v2::PropertyIx>{(property)}) +#define QUERY(...) memgraph::expr::test_common::GetQuery(storage, __VA_ARGS__) +#define SINGLE_QUERY(...) memgraph::expr::test_common::GetSingleQuery(storage.Create<SingleQuery>(), __VA_ARGS__) +#define UNION(...) memgraph::expr::test_common::GetCypherUnion(storage.Create<CypherUnion>(true), __VA_ARGS__) +#define UNION_ALL(...) memgraph::expr::test_common::GetCypherUnion(storage.Create<CypherUnion>(false), __VA_ARGS__) +#define FOREACH(...) memgraph::expr::test_common::GetForeach(storage, __VA_ARGS__) +// Various operators +#define NOT(expr) storage.Create<memgraph::query::v2::NotOperator>((expr)) +#define UPLUS(expr) storage.Create<memgraph::query::v2::UnaryPlusOperator>((expr)) +#define UMINUS(expr) storage.Create<memgraph::query::v2::UnaryMinusOperator>((expr)) +#define IS_NULL(expr) storage.Create<memgraph::query::v2::IsNullOperator>((expr)) +#define ADD(expr1, expr2) storage.Create<memgraph::query::v2::AdditionOperator>((expr1), (expr2)) +#define LESS(expr1, expr2) storage.Create<memgraph::query::v2::LessOperator>((expr1), (expr2)) +#define LESS_EQ(expr1, expr2) storage.Create<memgraph::query::v2::LessEqualOperator>((expr1), (expr2)) +#define GREATER(expr1, expr2) storage.Create<memgraph::query::v2::GreaterOperator>((expr1), (expr2)) +#define GREATER_EQ(expr1, expr2) storage.Create<memgraph::query::v2::GreaterEqualOperator>((expr1), (expr2)) +#define SUM(expr) \ + storage.Create<memgraph::query::v2::Aggregation>((expr), nullptr, memgraph::query::v2::Aggregation::Op::SUM) +#define COUNT(expr) \ + storage.Create<memgraph::query::v2::Aggregation>((expr), nullptr, memgraph::query::v2::Aggregation::Op::COUNT) +#define AVG(expr) \ + storage.Create<memgraph::query::v2::Aggregation>((expr), nullptr, memgraph::query::v2::Aggregation::Op::AVG) +#define COLLECT_LIST(expr) \ + storage.Create<memgraph::query::v2::Aggregation>((expr), nullptr, memgraph::query::v2::Aggregation::Op::COLLECT_LIST) +#define EQ(expr1, expr2) storage.Create<memgraph::query::v2::EqualOperator>((expr1), (expr2)) +#define NEQ(expr1, expr2) storage.Create<memgraph::query::v2::NotEqualOperator>((expr1), (expr2)) +#define AND(expr1, expr2) storage.Create<memgraph::query::v2::AndOperator>((expr1), (expr2)) +#define OR(expr1, expr2) storage.Create<memgraph::query::v2::OrOperator>((expr1), (expr2)) +#define IN_LIST(expr1, expr2) storage.Create<memgraph::query::v2::InListOperator>((expr1), (expr2)) +#define IF(cond, then, else) storage.Create<memgraph::query::v2::IfOperator>((cond), (then), (else)) +// Function call +#define FN(function_name, ...) \ + storage.Create<memgraph::query::v2::Function>(memgraph::utils::ToUpperCase(function_name), \ + std::vector<memgraph::query::v2::Expression *>{__VA_ARGS__}) +// List slicing +#define SLICE(list, lower_bound, upper_bound) \ + storage.Create<memgraph::query::v2::ListSlicingOperator>(list, lower_bound, upper_bound) +// all(variable IN list WHERE predicate) +#define ALL(variable, list, where) \ + storage.Create<memgraph::query::v2::All>(storage.Create<memgraph::query::v2::Identifier>(variable), list, where) +#define SINGLE(variable, list, where) \ + storage.Create<memgraph::query::v2::Single>(storage.Create<memgraph::query::v2::Identifier>(variable), list, where) +#define ANY(variable, list, where) \ + storage.Create<memgraph::query::v2::Any>(storage.Create<memgraph::query::v2::Identifier>(variable), list, where) +#define NONE(variable, list, where) \ + storage.Create<memgraph::query::v2::None>(storage.Create<memgraph::query::v2::Identifier>(variable), list, where) +#define REDUCE(accumulator, initializer, variable, list, expr) \ + storage.Create<memgraph::query::v2::Reduce>(storage.Create<memgraph::query::v2::Identifier>(accumulator), \ + initializer, storage.Create<memgraph::query::v2::Identifier>(variable), \ + list, expr) +#define COALESCE(...) \ + storage.Create<memgraph::query::v2::Coalesce>(std::vector<memgraph::query::v2::Expression *>{__VA_ARGS__}) +#define EXTRACT(variable, list, expr) \ + storage.Create<memgraph::query::v2::Extract>(storage.Create<memgraph::query::v2::Identifier>(variable), list, expr) +#define AUTH_QUERY(action, user, role, user_or_role, password, privileges) \ + storage.Create<memgraph::query::v2::AuthQuery>((action), (user), (role), (user_or_role), password, (privileges)) +#define DROP_USER(usernames) storage.Create<memgraph::query::v2::DropUser>((usernames)) +#define CALL_PROCEDURE(...) memgraph::query::v2::test_common::GetCallProcedure(storage, __VA_ARGS__) diff --git a/tests/unit/query_v2_expression_evaluator.cpp b/tests/unit/query_v2_expression_evaluator.cpp index 5e91d0d5a..b941d3463 100644 --- a/tests/unit/query_v2_expression_evaluator.cpp +++ b/tests/unit/query_v2_expression_evaluator.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -125,6 +125,11 @@ class MockedRequestRouter : public RequestRouterInterface { bool IsPrimaryKey(LabelId primary_label, PropertyId property) const override { return true; } + const std::vector<coordinator::SchemaProperty> &GetSchemaForLabel(storage::v3::LabelId /*label*/) const override { + static std::vector<coordinator::SchemaProperty> schema; + return schema; + }; + private: void SetUpNameIdMappers() { std::unordered_map<uint64_t, std::string> id_to_name; diff --git a/tests/unit/query_v2_plan.cpp b/tests/unit/query_v2_plan.cpp new file mode 100644 index 000000000..8fc5e11c7 --- /dev/null +++ b/tests/unit/query_v2_plan.cpp @@ -0,0 +1,178 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "query_plan_checker_v2.hpp" + +#include <iostream> +#include <list> +#include <sstream> +#include <tuple> +#include <typeinfo> +#include <unordered_set> +#include <variant> + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +#include "expr/semantic/symbol_generator.hpp" +#include "query/frontend/semantic/symbol_table.hpp" +#include "query/v2/frontend/ast/ast.hpp" +#include "query/v2/plan/operator.hpp" +#include "query/v2/plan/planner.hpp" + +#include "query_v2_common.hpp" + +namespace memgraph::query { +::std::ostream &operator<<(::std::ostream &os, const Symbol &sym) { + return os << "Symbol{\"" << sym.name() << "\" [" << sym.position() << "] " << Symbol::TypeToString(sym.type()) << "}"; +} +} // namespace memgraph::query + +// using namespace memgraph::query::v2::plan; +using namespace memgraph::expr::plan; +using memgraph::query::Symbol; +using memgraph::query::SymbolGenerator; +using memgraph::query::v2::AstStorage; +using memgraph::query::v2::SingleQuery; +using memgraph::query::v2::SymbolTable; +using Type = memgraph::query::v2::EdgeAtom::Type; +using Direction = memgraph::query::v2::EdgeAtom::Direction; +using Bound = ScanAllByLabelPropertyRange::Bound; + +namespace { + +class Planner { + public: + template <class TDbAccessor> + Planner(std::vector<SingleQueryPart> single_query_parts, PlanningContext<TDbAccessor> context) { + memgraph::expr::Parameters parameters; + PostProcessor post_processor(parameters); + plan_ = MakeLogicalPlanForSingleQuery<RuleBasedPlanner>(single_query_parts, &context); + plan_ = post_processor.Rewrite(std::move(plan_), &context); + } + + auto &plan() { return *plan_; } + + private: + std::unique_ptr<LogicalOperator> plan_; +}; + +template <class... TChecker> +auto CheckPlan(LogicalOperator &plan, const SymbolTable &symbol_table, TChecker... checker) { + std::list<BaseOpChecker *> checkers{&checker...}; + PlanChecker plan_checker(checkers, symbol_table); + plan.Accept(plan_checker); + EXPECT_TRUE(plan_checker.checkers_.empty()); +} + +template <class TPlanner, class... TChecker> +auto CheckPlan(memgraph::query::v2::CypherQuery *query, AstStorage &storage, TChecker... checker) { + auto symbol_table = memgraph::expr::MakeSymbolTable(query); + FakeDistributedDbAccessor dba; + auto planner = MakePlanner<TPlanner>(&dba, storage, symbol_table, query); + CheckPlan(planner.plan(), symbol_table, checker...); +} + +template <class T> +class TestPlanner : public ::testing::Test {}; + +using PlannerTypes = ::testing::Types<Planner>; + +TYPED_TEST_CASE(TestPlanner, PlannerTypes); + +TYPED_TEST(TestPlanner, MatchFilterPropIsNotNull) { + const char *prim_label_name = "prim_label_one"; + // Exact primary key match, one elem as PK. + { + FakeDistributedDbAccessor dba; + auto label = dba.Label(prim_label_name); + auto prim_prop_one = PRIMARY_PROPERTY_PAIR("prim_prop_one"); + + dba.SetIndexCount(label, 1); + dba.SetIndexCount(label, prim_prop_one.second, 1); + + dba.CreateSchema(label, {prim_prop_one.second}); + + memgraph::query::v2::AstStorage storage; + + memgraph::query::v2::Expression *expected_primary_key; + expected_primary_key = PROPERTY_LOOKUP("n", prim_prop_one); + auto *query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n", prim_label_name))), + WHERE(EQ(PROPERTY_LOOKUP("n", prim_prop_one), LITERAL(1))), RETURN("n"))); + auto symbol_table = (memgraph::expr::MakeSymbolTable(query)); + auto planner = MakePlanner<TypeParam>(&dba, storage, symbol_table, query); + CheckPlan(planner.plan(), symbol_table, ExpectScanByPrimaryKey(label, {expected_primary_key}), ExpectProduce()); + } + // Exact primary key match, two elem as PK. + { + FakeDistributedDbAccessor dba; + auto label = dba.Label(prim_label_name); + auto prim_prop_one = PRIMARY_PROPERTY_PAIR("prim_prop_one"); + + auto prim_prop_two = PRIMARY_PROPERTY_PAIR("prim_prop_two"); + auto sec_prop_one = PRIMARY_PROPERTY_PAIR("sec_prop_one"); + auto sec_prop_two = PRIMARY_PROPERTY_PAIR("sec_prop_two"); + auto sec_prop_three = PRIMARY_PROPERTY_PAIR("sec_prop_three"); + + dba.SetIndexCount(label, 1); + dba.SetIndexCount(label, prim_prop_one.second, 1); + + dba.CreateSchema(label, {prim_prop_one.second, prim_prop_two.second}); + + dba.SetIndexCount(label, prim_prop_two.second, 1); + dba.SetIndexCount(label, sec_prop_one.second, 1); + dba.SetIndexCount(label, sec_prop_two.second, 1); + dba.SetIndexCount(label, sec_prop_three.second, 1); + memgraph::query::v2::AstStorage storage; + + memgraph::query::v2::Expression *expected_primary_key; + expected_primary_key = PROPERTY_LOOKUP("n", prim_prop_one); + auto *query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n", prim_label_name))), + WHERE(AND(EQ(PROPERTY_LOOKUP("n", prim_prop_one), LITERAL(1)), + EQ(PROPERTY_LOOKUP("n", prim_prop_two), LITERAL(1)))), + RETURN("n"))); + auto symbol_table = (memgraph::expr::MakeSymbolTable(query)); + auto planner = MakePlanner<TypeParam>(&dba, storage, symbol_table, query); + CheckPlan(planner.plan(), symbol_table, ExpectScanByPrimaryKey(label, {expected_primary_key}), ExpectProduce()); + } + // One elem is missing from PK, default to ScanAllByLabelPropertyValue. + { + FakeDistributedDbAccessor dba; + auto label = dba.Label(prim_label_name); + + auto prim_prop_one = PRIMARY_PROPERTY_PAIR("prim_prop_one"); + auto prim_prop_two = PRIMARY_PROPERTY_PAIR("prim_prop_two"); + + auto sec_prop_one = PRIMARY_PROPERTY_PAIR("sec_prop_one"); + auto sec_prop_two = PRIMARY_PROPERTY_PAIR("sec_prop_two"); + auto sec_prop_three = PRIMARY_PROPERTY_PAIR("sec_prop_three"); + + dba.SetIndexCount(label, 1); + dba.SetIndexCount(label, prim_prop_one.second, 1); + + dba.CreateSchema(label, {prim_prop_one.second, prim_prop_two.second}); + + dba.SetIndexCount(label, prim_prop_two.second, 1); + dba.SetIndexCount(label, sec_prop_one.second, 1); + dba.SetIndexCount(label, sec_prop_two.second, 1); + dba.SetIndexCount(label, sec_prop_three.second, 1); + memgraph::query::v2::AstStorage storage; + + auto *query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n", prim_label_name))), + WHERE(EQ(PROPERTY_LOOKUP("n", prim_prop_one), LITERAL(1))), RETURN("n"))); + auto symbol_table = (memgraph::expr::MakeSymbolTable(query)); + auto planner = MakePlanner<TypeParam>(&dba, storage, symbol_table, query); + CheckPlan(planner.plan(), symbol_table, ExpectScanAllByLabelPropertyValue(label, prim_prop_one, IDENT("n")), + ExpectProduce()); + } +} + +} // namespace