diff --git a/src/expr/interpret/frame.hpp b/src/expr/interpret/frame.hpp index 9f4068226..457806680 100644 --- a/src/expr/interpret/frame.hpp +++ b/src/expr/interpret/frame.hpp @@ -23,9 +23,9 @@ namespace memgraph::expr { class Frame { public: /// Create a Frame of given size backed by a utils::NewDeleteResource() - explicit Frame(int64_t size) : elems_(size, utils::NewDeleteResource()) { MG_ASSERT(size >= 0); } + explicit Frame(size_t size) : elems_(size, utils::NewDeleteResource()) { MG_ASSERT(size >= 0); } - Frame(int64_t size, utils::MemoryResource *memory) : elems_(size, memory) { MG_ASSERT(size >= 0); } + Frame(size_t size, utils::MemoryResource *memory) : elems_(size, memory) { MG_ASSERT(size >= 0); } TypedValue &operator[](const Symbol &symbol) { return elems_[symbol.position()]; } const TypedValue &operator[](const Symbol &symbol) const { return elems_[symbol.position()]; } @@ -44,9 +44,9 @@ class Frame { class FrameWithValidity final : public Frame { public: - explicit FrameWithValidity(int64_t size) : Frame(size), is_valid_(false) {} + explicit FrameWithValidity(size_t size) : Frame(size), is_valid_(false) {} - FrameWithValidity(int64_t size, utils::MemoryResource *memory) : Frame(size, memory), is_valid_(false) {} + FrameWithValidity(size_t size, utils::MemoryResource *memory) : Frame(size, memory), is_valid_(false) {} bool IsValid() const noexcept { return is_valid_; } void MakeValid() noexcept { is_valid_ = true; } diff --git a/src/io/local_transport/local_system.hpp b/src/io/local_transport/local_system.hpp index 2e54f8d75..7b0cda537 100644 --- a/src/io/local_transport/local_system.hpp +++ b/src/io/local_transport/local_system.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source diff --git a/src/io/simulator/simulator.hpp b/src/io/simulator/simulator.hpp index 8afc073af..4ff8e7ae0 100644 --- a/src/io/simulator/simulator.hpp +++ b/src/io/simulator/simulator.hpp @@ -41,7 +41,7 @@ class Simulator { Io<SimulatorTransport> Register(Address address) { std::uniform_int_distribution<uint64_t> seed_distrib; uint64_t seed = seed_distrib(rng_); - return Io{SimulatorTransport{simulator_handle_, address, seed}, address}; + return Io{SimulatorTransport(simulator_handle_, address, seed), address}; } void IncrementServerCountAndWaitForQuiescentState(Address address) { @@ -50,8 +50,12 @@ class Simulator { SimulatorStats Stats() { return simulator_handle_->Stats(); } + std::shared_ptr<SimulatorHandle> GetSimulatorHandle() const { return simulator_handle_; } + std::function<bool()> GetSimulatorTickClosure() { - std::function<bool()> tick_closure = [handle_copy = simulator_handle_] { return handle_copy->MaybeTickSimulator(); }; + std::function<bool()> tick_closure = [handle_copy = simulator_handle_] { + return handle_copy->MaybeTickSimulator(); + }; return tick_closure; } }; diff --git a/src/io/simulator/simulator_transport.hpp b/src/io/simulator/simulator_transport.hpp index 038cfeb03..1272a04a1 100644 --- a/src/io/simulator/simulator_transport.hpp +++ b/src/io/simulator/simulator_transport.hpp @@ -26,7 +26,7 @@ using memgraph::io::Time; class SimulatorTransport { std::shared_ptr<SimulatorHandle> simulator_handle_; - const Address address_; + Address address_; std::mt19937 rng_; public: @@ -36,7 +36,9 @@ class SimulatorTransport { template <Message RequestT, Message ResponseT> ResponseFuture<ResponseT> Request(Address to_address, Address from_address, RequestT request, std::function<void()> notification, Duration timeout) { - std::function<bool()> tick_simulator = [handle_copy = simulator_handle_] { return handle_copy->MaybeTickSimulator(); }; + std::function<bool()> tick_simulator = [handle_copy = simulator_handle_] { + return handle_copy->MaybeTickSimulator(); + }; return simulator_handle_->template SubmitRequest<RequestT, ResponseT>( to_address, from_address, std::move(request), timeout, std::move(tick_simulator), std::move(notification)); diff --git a/src/memgraph.cpp b/src/memgraph.cpp index d825cc0e7..35fd20ad7 100644 --- a/src/memgraph.cpp +++ b/src/memgraph.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -640,6 +640,8 @@ int main(int argc, char **argv) { memgraph::machine_manager::MachineManager<memgraph::io::local_transport::LocalTransport> mm{io, config, coordinator}; std::jthread mm_thread([&mm] { mm.Run(); }); + auto rr_factory = std::make_unique<memgraph::query::v2::LocalRequestRouterFactory>(io); + memgraph::query::v2::InterpreterContext interpreter_context{ (memgraph::storage::v3::Shard *)(nullptr), {.query = {.allow_load_csv = FLAGS_allow_load_csv}, @@ -650,7 +652,7 @@ int main(int argc, char **argv) { .stream_transaction_conflict_retries = FLAGS_stream_transaction_conflict_retries, .stream_transaction_retry_interval = std::chrono::milliseconds(FLAGS_stream_transaction_retry_interval)}, FLAGS_data_directory, - std::move(io), + std::move(rr_factory), mm.CoordinatorAddress()}; SessionData session_data{&interpreter_context}; diff --git a/src/query/v2/context.hpp b/src/query/v2/context.hpp index cb30a9ced..58f5ada97 100644 --- a/src/query/v2/context.hpp +++ b/src/query/v2/context.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source diff --git a/src/query/v2/interpreter.cpp b/src/query/v2/interpreter.cpp index c7943e6df..aa220d764 100644 --- a/src/query/v2/interpreter.cpp +++ b/src/query/v2/interpreter.cpp @@ -907,34 +907,24 @@ using RWType = plan::ReadWriteTypeChecker::RWType; InterpreterContext::InterpreterContext(storage::v3::Shard *db, const InterpreterConfig config, const std::filesystem::path & /*data_directory*/, - io::Io<io::local_transport::LocalTransport> io, + std::unique_ptr<RequestRouterFactory> request_router_factory, coordinator::Address coordinator_addr) - : db(db), config(config), io{std::move(io)}, coordinator_address{coordinator_addr} {} + : db(db), + config(config), + coordinator_address{coordinator_addr}, + request_router_factory_{std::move(request_router_factory)} {} Interpreter::Interpreter(InterpreterContext *interpreter_context) : interpreter_context_(interpreter_context) { MG_ASSERT(interpreter_context_, "Interpreter context must not be NULL"); - // TODO(tyler) make this deterministic so that it can be tested. - auto random_uuid = boost::uuids::uuid{boost::uuids::random_generator()()}; - auto query_io = interpreter_context_->io.ForkLocal(random_uuid); + request_router_ = + interpreter_context_->request_router_factory_->CreateRequestRouter(interpreter_context_->coordinator_address); - request_router_ = std::make_unique<RequestRouter<io::local_transport::LocalTransport>>( - coordinator::CoordinatorClient<io::local_transport::LocalTransport>( - query_io, interpreter_context_->coordinator_address, std::vector{interpreter_context_->coordinator_address}), - std::move(query_io)); // Get edge ids - coordinator::CoordinatorWriteRequests requests{coordinator::AllocateEdgeIdBatchRequest{.batch_size = 1000000}}; - io::rsm::WriteRequest<coordinator::CoordinatorWriteRequests> ww; - ww.operation = requests; - auto resp = interpreter_context_->io - .Request<io::rsm::WriteRequest<coordinator::CoordinatorWriteRequests>, - io::rsm::WriteResponse<coordinator::CoordinatorWriteResponses>>( - interpreter_context_->coordinator_address, ww) - .Wait(); - if (resp.HasValue()) { - const auto alloc_edge_id_reps = - std::get<coordinator::AllocateEdgeIdBatchResponse>(resp.GetValue().message.write_return); - interpreter_context_->edge_ids_alloc = {alloc_edge_id_reps.low, alloc_edge_id_reps.high}; + const auto edge_ids_alloc_min_max_pair = + request_router_->AllocateInitialEdgeIds(interpreter_context_->coordinator_address); + if (edge_ids_alloc_min_max_pair) { + interpreter_context_->edge_ids_alloc = {edge_ids_alloc_min_max_pair->first, edge_ids_alloc_min_max_pair->second}; } } diff --git a/src/query/v2/interpreter.hpp b/src/query/v2/interpreter.hpp index 985c9a90c..4efc85c22 100644 --- a/src/query/v2/interpreter.hpp +++ b/src/query/v2/interpreter.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -16,7 +16,6 @@ #include "coordinator/coordinator.hpp" #include "coordinator/coordinator_client.hpp" -#include "io/local_transport/local_transport.hpp" #include "io/transport.hpp" #include "query/v2/auth_checker.hpp" #include "query/v2/bindings/cypher_main_visitor.hpp" @@ -172,7 +171,8 @@ struct PreparedQuery { struct InterpreterContext { explicit InterpreterContext(storage::v3::Shard *db, InterpreterConfig config, const std::filesystem::path &data_directory, - io::Io<io::local_transport::LocalTransport> io, coordinator::Address coordinator_addr); + std::unique_ptr<RequestRouterFactory> request_router_factory, + coordinator::Address coordinator_addr); storage::v3::Shard *db; @@ -188,26 +188,24 @@ struct InterpreterContext { const InterpreterConfig config; IdAllocator edge_ids_alloc; - // TODO (antaljanosbenjamin) Figure out an abstraction for io::Io to make it possible to construct an interpreter - // context with a simulator transport without templatizing it. - io::Io<io::local_transport::LocalTransport> io; coordinator::Address coordinator_address; + std::unique_ptr<RequestRouterFactory> request_router_factory_; storage::v3::LabelId NameToLabelId(std::string_view label_name) { - return storage::v3::LabelId::FromUint(query_id_mapper.NameToId(label_name)); + return storage::v3::LabelId::FromUint(query_id_mapper_.NameToId(label_name)); } storage::v3::PropertyId NameToPropertyId(std::string_view property_name) { - return storage::v3::PropertyId::FromUint(query_id_mapper.NameToId(property_name)); + return storage::v3::PropertyId::FromUint(query_id_mapper_.NameToId(property_name)); } storage::v3::EdgeTypeId NameToEdgeTypeId(std::string_view edge_type_name) { - return storage::v3::EdgeTypeId::FromUint(query_id_mapper.NameToId(edge_type_name)); + return storage::v3::EdgeTypeId::FromUint(query_id_mapper_.NameToId(edge_type_name)); } private: // TODO Replace with local map of labels, properties and edge type ids - storage::v3::NameIdMapper query_id_mapper; + storage::v3::NameIdMapper query_id_mapper_; }; /// Function that is used to tell all active interpreters that they should stop @@ -297,12 +295,15 @@ class Interpreter final { void Abort(); const RequestRouterInterface *GetRequestRouter() const { return request_router_.get(); } + void InstallSimulatorTicker(std::function<bool()> &&tick_simulator) { + request_router_->InstallSimulatorTicker(tick_simulator); + } private: struct QueryExecution { - std::optional<PreparedQuery> prepared_query; utils::MonotonicBufferResource execution_memory{kExecutionMemoryBlockSize}; utils::ResourceWithOutOfMemoryException execution_memory_with_exception{&execution_memory}; + std::optional<PreparedQuery> prepared_query; std::map<std::string, TypedValue> summary; std::vector<Notification> notifications; diff --git a/src/query/v2/plan/cost_estimator.hpp b/src/query/v2/plan/cost_estimator.hpp index 8cba2505f..f497d14d5 100644 --- a/src/query/v2/plan/cost_estimator.hpp +++ b/src/query/v2/plan/cost_estimator.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -149,8 +149,6 @@ class CostEstimator : public HierarchicalLogicalOperatorVisitor { return true; } - // TODO: Cost estimate ScanAllById? - // For the given op first increments the cardinality and then cost. #define POST_VISIT_CARD_FIRST(NAME) \ bool PostVisit(NAME &) override { \ diff --git a/src/query/v2/plan/operator.cpp b/src/query/v2/plan/operator.cpp index 64a8c6f8c..ee166e3cc 100644 --- a/src/query/v2/plan/operator.cpp +++ b/src/query/v2/plan/operator.cpp @@ -34,6 +34,7 @@ #include "query/v2/bindings/eval.hpp" #include "query/v2/bindings/symbol_table.hpp" #include "query/v2/context.hpp" +#include "query/v2/conversions.hpp" #include "query/v2/db_accessor.hpp" #include "query/v2/exceptions.hpp" #include "query/v2/frontend/ast/ast.hpp" @@ -93,7 +94,7 @@ extern const Event ScanAllByLabelOperator; extern const Event ScanAllByLabelPropertyRangeOperator; extern const Event ScanAllByLabelPropertyValueOperator; extern const Event ScanAllByLabelPropertyOperator; -extern const Event ScanAllByIdOperator; +extern const Event ScanByPrimaryKeyOperator; extern const Event ExpandOperator; extern const Event ExpandVariableOperator; extern const Event ConstructNamedPathOperator; @@ -589,6 +590,88 @@ class DistributedScanAllAndFilterCursor : public Cursor { ValidFramesConsumer::Iterator own_frames_it_; }; +class DistributedScanByPrimaryKeyCursor : public Cursor { + public: + explicit DistributedScanByPrimaryKeyCursor(Symbol output_symbol, UniqueCursorPtr input_cursor, const char *op_name, + storage::v3::LabelId label, + std::optional<std::vector<Expression *>> filter_expressions, + std::vector<Expression *> primary_key) + : output_symbol_(output_symbol), + input_cursor_(std::move(input_cursor)), + op_name_(op_name), + label_(label), + filter_expressions_(filter_expressions), + primary_key_(primary_key) {} + + enum class State : int8_t { INITIALIZING, COMPLETED }; + + using VertexAccessor = accessors::VertexAccessor; + + std::optional<VertexAccessor> MakeRequestSingleFrame(Frame &frame, RequestRouterInterface &request_router, + ExecutionContext &context) { + // Evaluate the expressions that hold the PrimaryKey. + ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.request_router, + storage::v3::View::NEW); + + std::vector<msgs::Value> pk; + for (auto *primary_property : primary_key_) { + pk.push_back(TypedValueToValue(primary_property->Accept(evaluator))); + } + + msgs::Label label = {.id = msgs::LabelId::FromUint(label_.AsUint())}; + + msgs::GetPropertiesRequest req = {.vertex_ids = {std::make_pair(label, pk)}}; + auto get_prop_result = std::invoke([&context, &request_router, &req]() mutable { + SCOPED_REQUEST_WAIT_PROFILE; + return request_router.GetProperties(req); + }); + MG_ASSERT(get_prop_result.size() <= 1); + + if (get_prop_result.empty()) { + return std::nullopt; + } + auto properties = get_prop_result[0].props; + // TODO (gvolfing) figure out labels when relevant. + msgs::Vertex vertex = {.id = get_prop_result[0].vertex, .labels = {}}; + + return VertexAccessor(vertex, properties, &request_router); + } + + bool Pull(Frame &frame, ExecutionContext &context) override { + SCOPED_PROFILE_OP(op_name_); + + if (MustAbort(context)) { + throw HintedAbortError(); + } + + while (input_cursor_->Pull(frame, context)) { + auto &request_router = *context.request_router; + auto vertex = MakeRequestSingleFrame(frame, request_router, context); + if (vertex) { + frame[output_symbol_] = TypedValue(std::move(*vertex)); + return true; + } + } + return false; + } + + void PullMultiple(MultiFrame & /*input_multi_frame*/, ExecutionContext & /*context*/) override { + throw utils::NotYetImplemented("Multiframe version of ScanByPrimaryKey is yet to be implemented."); + }; + + void Reset() override { input_cursor_->Reset(); } + + void Shutdown() override { input_cursor_->Shutdown(); } + + private: + const Symbol output_symbol_; + const UniqueCursorPtr input_cursor_; + const char *op_name_; + storage::v3::LabelId label_; + std::optional<std::vector<Expression *>> filter_expressions_; + std::vector<Expression *> primary_key_; +}; + ScanAll::ScanAll(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, storage::v3::View view) : input_(input ? input : std::make_shared<Once>()), output_symbol_(output_symbol), view_(view) {} @@ -683,22 +766,21 @@ UniqueCursorPtr ScanAllByLabelProperty::MakeCursor(utils::MemoryResource *mem) c throw QueryRuntimeException("ScanAllByLabelProperty is not supported"); } -ScanAllById::ScanAllById(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, Expression *expression, - storage::v3::View view) - : ScanAll(input, output_symbol, view), expression_(expression) { - MG_ASSERT(expression); +ScanByPrimaryKey::ScanByPrimaryKey(const std::shared_ptr<LogicalOperator> &input, Symbol output_symbol, + storage::v3::LabelId label, std::vector<query::v2::Expression *> primary_key, + storage::v3::View view) + : ScanAll(input, output_symbol, view), label_(label), primary_key_(primary_key) { + MG_ASSERT(primary_key.front()); } -ACCEPT_WITH_INPUT(ScanAllById) +ACCEPT_WITH_INPUT(ScanByPrimaryKey) -UniqueCursorPtr ScanAllById::MakeCursor(utils::MemoryResource *mem) const { - EventCounter::IncrementCounter(EventCounter::ScanAllByIdOperator); - // TODO Reimplement when we have reliable conversion between hash value and pk - auto vertices = [](Frame & /*frame*/, ExecutionContext & /*context*/) -> std::optional<std::vector<VertexAccessor>> { - return std::nullopt; - }; - return MakeUniqueCursorPtr<ScanAllCursor<decltype(vertices)>>(mem, output_symbol_, input_->MakeCursor(mem), - std::move(vertices), "ScanAllById"); +UniqueCursorPtr ScanByPrimaryKey::MakeCursor(utils::MemoryResource *mem) const { + EventCounter::IncrementCounter(EventCounter::ScanByPrimaryKeyOperator); + + return MakeUniqueCursorPtr<DistributedScanByPrimaryKeyCursor>(mem, output_symbol_, input_->MakeCursor(mem), + "ScanByPrimaryKey", label_, + std::nullopt /*filter_expressions*/, primary_key_); } Expand::Expand(const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, Symbol node_symbol, diff --git a/src/query/v2/plan/operator.lcp b/src/query/v2/plan/operator.lcp index 91726752b..4f34cc061 100644 --- a/src/query/v2/plan/operator.lcp +++ b/src/query/v2/plan/operator.lcp @@ -127,7 +127,7 @@ class ScanAllByLabel; class ScanAllByLabelPropertyRange; class ScanAllByLabelPropertyValue; class ScanAllByLabelProperty; -class ScanAllById; +class ScanByPrimaryKey; class Expand; class ExpandVariable; class ConstructNamedPath; @@ -158,7 +158,7 @@ class Foreach; using LogicalOperatorCompositeVisitor = utils::CompositeVisitor< Once, CreateNode, CreateExpand, ScanAll, ScanAllByLabel, ScanAllByLabelPropertyRange, ScanAllByLabelPropertyValue, - ScanAllByLabelProperty, ScanAllById, + ScanAllByLabelProperty, ScanByPrimaryKey, Expand, ExpandVariable, ConstructNamedPath, Filter, Produce, Delete, SetProperty, SetProperties, SetLabels, RemoveProperty, RemoveLabels, EdgeUniquenessFilter, Accumulate, Aggregate, Skip, Limit, OrderBy, Merge, @@ -859,19 +859,21 @@ given label and property. (:serialize (:slk)) (:clone)) - - -(lcp:define-class scan-all-by-id (scan-all) - ((expression "Expression *" :scope :public +(lcp:define-class scan-by-primary-key (scan-all) + ((label "::storage::v3::LabelId" :scope :public) + (primary-key "std::vector<Expression*>" :scope :public) + (expression "Expression *" :scope :public :slk-save #'slk-save-ast-pointer :slk-load (slk-load-ast-pointer "Expression"))) (:documentation - "ScanAll producing a single node with ID equal to evaluated expression") + "ScanAll producing a single node with specified by the label and primary key") (:public #>cpp - ScanAllById() {} - ScanAllById(const std::shared_ptr<LogicalOperator> &input, - Symbol output_symbol, Expression *expression, + ScanByPrimaryKey() {} + ScanByPrimaryKey(const std::shared_ptr<LogicalOperator> &input, + Symbol output_symbol, + storage::v3::LabelId label, + std::vector<query::v2::Expression*> primary_key, storage::v3::View view = storage::v3::View::OLD); bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; diff --git a/src/query/v2/plan/pretty_print.cpp b/src/query/v2/plan/pretty_print.cpp index bc30a3890..7eb1dd5a9 100644 --- a/src/query/v2/plan/pretty_print.cpp +++ b/src/query/v2/plan/pretty_print.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -86,10 +86,10 @@ bool PlanPrinter::PreVisit(query::v2::plan::ScanAllByLabelProperty &op) { return true; } -bool PlanPrinter::PreVisit(ScanAllById &op) { +bool PlanPrinter::PreVisit(query::v2::plan::ScanByPrimaryKey &op) { WithPrintLn([&](auto &out) { - out << "* ScanAllById" - << " (" << op.output_symbol_.name() << ")"; + out << "* ScanByPrimaryKey" + << " (" << op.output_symbol_.name() << " :" << request_router_->LabelToName(op.label_) << ")"; }); return true; } @@ -487,12 +487,15 @@ bool PlanToJsonVisitor::PreVisit(ScanAllByLabelProperty &op) { return false; } -bool PlanToJsonVisitor::PreVisit(ScanAllById &op) { +bool PlanToJsonVisitor::PreVisit(ScanByPrimaryKey &op) { json self; - self["name"] = "ScanAllById"; + self["name"] = "ScanByPrimaryKey"; + self["label"] = ToJson(op.label_, *request_router_); self["output_symbol"] = ToJson(op.output_symbol_); + op.input_->Accept(*this); self["input"] = PopOutput(); + output_ = std::move(self); return false; } diff --git a/src/query/v2/plan/pretty_print.hpp b/src/query/v2/plan/pretty_print.hpp index 4094d7c81..31ff17b18 100644 --- a/src/query/v2/plan/pretty_print.hpp +++ b/src/query/v2/plan/pretty_print.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -67,7 +67,7 @@ class PlanPrinter : public virtual HierarchicalLogicalOperatorVisitor { bool PreVisit(ScanAllByLabelPropertyValue &) override; bool PreVisit(ScanAllByLabelPropertyRange &) override; bool PreVisit(ScanAllByLabelProperty &) override; - bool PreVisit(ScanAllById &) override; + bool PreVisit(ScanByPrimaryKey & /*unused*/) override; bool PreVisit(Expand &) override; bool PreVisit(ExpandVariable &) override; @@ -194,7 +194,7 @@ class PlanToJsonVisitor : public virtual HierarchicalLogicalOperatorVisitor { bool PreVisit(ScanAllByLabelPropertyRange &) override; bool PreVisit(ScanAllByLabelPropertyValue &) override; bool PreVisit(ScanAllByLabelProperty &) override; - bool PreVisit(ScanAllById &) override; + bool PreVisit(ScanByPrimaryKey & /*unused*/) override; bool PreVisit(Produce &) override; bool PreVisit(Accumulate &) override; diff --git a/src/query/v2/plan/read_write_type_checker.cpp b/src/query/v2/plan/read_write_type_checker.cpp index 6cc38cedf..28c1e8d0a 100644 --- a/src/query/v2/plan/read_write_type_checker.cpp +++ b/src/query/v2/plan/read_write_type_checker.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -11,10 +11,12 @@ #include "query/v2/plan/read_write_type_checker.hpp" -#define PRE_VISIT(TOp, RWType, continue_visiting) \ - bool ReadWriteTypeChecker::PreVisit(TOp &op) { \ - UpdateType(RWType); \ - return continue_visiting; \ +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +#define PRE_VISIT(TOp, RWType, continue_visiting) \ + /*NOLINTNEXTLINE(bugprone-macro-parentheses)*/ \ + bool ReadWriteTypeChecker::PreVisit(TOp & /*op*/) { \ + UpdateType(RWType); \ + return continue_visiting; \ } namespace memgraph::query::v2::plan { @@ -35,7 +37,7 @@ PRE_VISIT(ScanAllByLabel, RWType::R, true) PRE_VISIT(ScanAllByLabelPropertyRange, RWType::R, true) PRE_VISIT(ScanAllByLabelPropertyValue, RWType::R, true) PRE_VISIT(ScanAllByLabelProperty, RWType::R, true) -PRE_VISIT(ScanAllById, RWType::R, true) +PRE_VISIT(ScanByPrimaryKey, RWType::R, true) PRE_VISIT(Expand, RWType::R, true) PRE_VISIT(ExpandVariable, RWType::R, true) diff --git a/src/query/v2/plan/read_write_type_checker.hpp b/src/query/v2/plan/read_write_type_checker.hpp index a3c2f1a46..626cf7af3 100644 --- a/src/query/v2/plan/read_write_type_checker.hpp +++ b/src/query/v2/plan/read_write_type_checker.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -59,7 +59,7 @@ class ReadWriteTypeChecker : public virtual HierarchicalLogicalOperatorVisitor { bool PreVisit(ScanAllByLabelPropertyValue &) override; bool PreVisit(ScanAllByLabelPropertyRange &) override; bool PreVisit(ScanAllByLabelProperty &) override; - bool PreVisit(ScanAllById &) override; + bool PreVisit(ScanByPrimaryKey & /*unused*/) override; bool PreVisit(Expand &) override; bool PreVisit(ExpandVariable &) override; diff --git a/src/query/v2/plan/rewrite/index_lookup.hpp b/src/query/v2/plan/rewrite/index_lookup.hpp index 57ddba54e..17996d952 100644 --- a/src/query/v2/plan/rewrite/index_lookup.hpp +++ b/src/query/v2/plan/rewrite/index_lookup.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -25,8 +25,10 @@ #include <gflags/gflags.h> +#include "query/v2/frontend/ast/ast.hpp" #include "query/v2/plan/operator.hpp" #include "query/v2/plan/preprocess.hpp" +#include "storage/v3/id_types.hpp" DECLARE_int64(query_vertex_count_to_expand_existing); @@ -271,11 +273,12 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor { return true; } - bool PreVisit(ScanAllById &op) override { + bool PreVisit(ScanByPrimaryKey &op) override { prev_ops_.push_back(&op); return true; } - bool PostVisit(ScanAllById &) override { + + bool PostVisit(ScanByPrimaryKey & /*unused*/) override { prev_ops_.pop_back(); return true; } @@ -487,6 +490,12 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor { storage::v3::PropertyId GetProperty(PropertyIx prop) { return db_->NameToProperty(prop.name); } + void EraseLabelFilters(const memgraph::query::v2::Symbol &node_symbol, memgraph::query::v2::LabelIx prim_label) { + std::vector<query::v2::Expression *> removed_expressions; + filters_.EraseLabelFilter(node_symbol, prim_label, &removed_expressions); + filter_exprs_for_removal_.insert(removed_expressions.begin(), removed_expressions.end()); + } + std::optional<LabelIx> FindBestLabelIndex(const std::unordered_set<LabelIx> &labels) { MG_ASSERT(!labels.empty(), "Trying to find the best label without any labels."); std::optional<LabelIx> best_label; @@ -559,31 +568,81 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor { const auto &view = scan.view_; const auto &modified_symbols = scan.ModifiedSymbols(*symbol_table_); std::unordered_set<Symbol> bound_symbols(modified_symbols.begin(), modified_symbols.end()); - auto are_bound = [&bound_symbols](const auto &used_symbols) { - for (const auto &used_symbol : used_symbols) { - if (!utils::Contains(bound_symbols, used_symbol)) { - return false; - } - } - return true; - }; - // First, try to see if we can find a vertex by ID. - if (!max_vertex_count || *max_vertex_count >= 1) { - for (const auto &filter : filters_.IdFilters(node_symbol)) { - if (filter.id_filter->is_symbol_in_value_ || !are_bound(filter.used_symbols)) continue; - auto *value = filter.id_filter->value_; - filter_exprs_for_removal_.insert(filter.expression); - filters_.EraseFilter(filter); - return std::make_unique<ScanAllById>(input, node_symbol, value, view); - } - } - // Now try to see if we can use label+property index. If not, try to use - // just the label index. + + // Try to see if we can use label + primary-key or label + property index. + // If not, try to use just the label index. const auto labels = filters_.FilteredLabels(node_symbol); if (labels.empty()) { // Without labels, we cannot generate any indexed ScanAll. return nullptr; } + + // First, try to see if we can find a vertex based on the possibly + // supplied primary key. + auto property_filters = filters_.PropertyFilters(node_symbol); + query::v2::LabelIx prim_label; + std::vector<std::pair<query::v2::Expression *, query::v2::plan::FilterInfo>> primary_key; + + auto extract_primary_key = [this](storage::v3::LabelId label, + std::vector<query::v2::plan::FilterInfo> property_filters) + -> std::vector<std::pair<query::v2::Expression *, query::v2::plan::FilterInfo>> { + std::vector<std::pair<query::v2::Expression *, query::v2::plan::FilterInfo>> pk_temp; + std::vector<std::pair<query::v2::Expression *, query::v2::plan::FilterInfo>> pk; + std::vector<memgraph::storage::v3::SchemaProperty> schema = db_->GetSchemaForLabel(label); + + std::vector<storage::v3::PropertyId> schema_properties; + schema_properties.reserve(schema.size()); + + std::transform(schema.begin(), schema.end(), std::back_inserter(schema_properties), + [](const auto &schema_elem) { return schema_elem.property_id; }); + + for (const auto &property_filter : property_filters) { + const auto &property_id = db_->NameToProperty(property_filter.property_filter->property_.name); + if (std::find(schema_properties.begin(), schema_properties.end(), property_id) != schema_properties.end()) { + pk_temp.emplace_back(std::make_pair(property_filter.expression, property_filter)); + } + } + + // Make sure pk is in the same order as schema_properties. + for (const auto &schema_prop : schema_properties) { + for (auto &pk_temp_prop : pk_temp) { + const auto &property_id = db_->NameToProperty(pk_temp_prop.second.property_filter->property_.name); + if (schema_prop == property_id) { + pk.push_back(pk_temp_prop); + } + } + } + MG_ASSERT(pk.size() == pk_temp.size(), + "The two vectors should represent the same primary key with a possibly different order of contained " + "elements."); + + return pk.size() == schema_properties.size() + ? pk + : std::vector<std::pair<query::v2::Expression *, query::v2::plan::FilterInfo>>{}; + }; + + if (!property_filters.empty()) { + for (const auto &label : labels) { + if (db_->PrimaryLabelExists(GetLabel(label))) { + prim_label = label; + primary_key = extract_primary_key(GetLabel(prim_label), property_filters); + break; + } + } + if (!primary_key.empty()) { + // Mark the expressions so they won't be used for an additional, unnecessary filter. + for (const auto &primary_property : primary_key) { + filter_exprs_for_removal_.insert(primary_property.first); + filters_.EraseFilter(primary_property.second); + } + EraseLabelFilters(node_symbol, prim_label); + std::vector<query::v2::Expression *> pk_expressions; + std::transform(primary_key.begin(), primary_key.end(), std::back_inserter(pk_expressions), + [](const auto &exp) { return exp.second.property_filter->value_; }); + return std::make_unique<ScanByPrimaryKey>(input, node_symbol, GetLabel(prim_label), pk_expressions); + } + } + auto found_index = FindBestLabelPropertyIndex(node_symbol, bound_symbols); if (found_index && // Use label+property index if we satisfy max_vertex_count. @@ -597,9 +656,7 @@ class IndexLookupRewriter final : public HierarchicalLogicalOperatorVisitor { filter_exprs_for_removal_.insert(found_index->filter.expression); } filters_.EraseFilter(found_index->filter); - std::vector<Expression *> removed_expressions; - filters_.EraseLabelFilter(node_symbol, found_index->label, &removed_expressions); - filter_exprs_for_removal_.insert(removed_expressions.begin(), removed_expressions.end()); + EraseLabelFilters(node_symbol, found_index->label); if (prop_filter.lower_bound_ || prop_filter.upper_bound_) { return std::make_unique<ScanAllByLabelPropertyRange>( input, node_symbol, GetLabel(found_index->label), GetProperty(prop_filter.property_), diff --git a/src/query/v2/plan/vertex_count_cache.hpp b/src/query/v2/plan/vertex_count_cache.hpp index e68ce1220..ae6cdad31 100644 --- a/src/query/v2/plan/vertex_count_cache.hpp +++ b/src/query/v2/plan/vertex_count_cache.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -12,14 +12,17 @@ /// @file #pragma once +#include <iterator> #include <optional> #include "query/v2/bindings/typed_value.hpp" +#include "query/v2/plan/preprocess.hpp" #include "query/v2/request_router.hpp" #include "storage/v3/conversions.hpp" #include "storage/v3/id_types.hpp" #include "storage/v3/property_value.hpp" #include "utils/bound.hpp" +#include "utils/exceptions.hpp" #include "utils/fnv.hpp" namespace memgraph::query::v2::plan { @@ -52,11 +55,16 @@ class VertexCountCache { return 1; } - // For now return true if label is primary label - bool LabelIndexExists(storage::v3::LabelId label) { return request_router_->IsPrimaryLabel(label); } + bool LabelIndexExists(storage::v3::LabelId label) { return PrimaryLabelExists(label); } + + bool PrimaryLabelExists(storage::v3::LabelId label) { return request_router_->IsPrimaryLabel(label); } bool LabelPropertyIndexExists(storage::v3::LabelId /*label*/, storage::v3::PropertyId /*property*/) { return false; } + const std::vector<memgraph::storage::v3::SchemaProperty> &GetSchemaForLabel(storage::v3::LabelId label) { + return request_router_->GetSchemaForLabel(label); + } + RequestRouterInterface *request_router_; }; diff --git a/src/query/v2/request_router.hpp b/src/query/v2/request_router.hpp index 96d51b05f..bf8c93566 100644 --- a/src/query/v2/request_router.hpp +++ b/src/query/v2/request_router.hpp @@ -24,14 +24,18 @@ #include <stdexcept> #include <thread> #include <unordered_map> +#include <variant> #include <vector> +#include <boost/uuid/uuid.hpp> + #include "coordinator/coordinator.hpp" #include "coordinator/coordinator_client.hpp" #include "coordinator/coordinator_rsm.hpp" #include "coordinator/shard_map.hpp" #include "io/address.hpp" #include "io/errors.hpp" +#include "io/local_transport/local_transport.hpp" #include "io/notifier.hpp" #include "io/rsm/raft.hpp" #include "io/rsm/rsm_client.hpp" @@ -114,6 +118,10 @@ class RequestRouterInterface { virtual std::optional<storage::v3::LabelId> MaybeNameToLabel(const std::string &name) const = 0; virtual bool IsPrimaryLabel(storage::v3::LabelId label) const = 0; virtual bool IsPrimaryKey(storage::v3::LabelId primary_label, storage::v3::PropertyId property) const = 0; + + virtual std::optional<std::pair<uint64_t, uint64_t>> AllocateInitialEdgeIds(io::Address coordinator_address) = 0; + virtual void InstallSimulatorTicker(std::function<bool()> tick_simulator) = 0; + virtual const std::vector<coordinator::SchemaProperty> &GetSchemaForLabel(storage::v3::LabelId label) const = 0; }; // TODO(kostasrim)rename this class template @@ -138,7 +146,7 @@ class RequestRouter : public RequestRouterInterface { ~RequestRouter() override {} - void InstallSimulatorTicker(std::function<bool()> tick_simulator) { + void InstallSimulatorTicker(std::function<bool()> tick_simulator) override { notifier_.InstallSimulatorTicker(tick_simulator); } @@ -232,12 +240,17 @@ class RequestRouter : public RequestRouterInterface { }) != schema_it->second.end(); } + const std::vector<coordinator::SchemaProperty> &GetSchemaForLabel(storage::v3::LabelId label) const override { + return shards_map_.schemas.at(label); + } + bool IsPrimaryLabel(storage::v3::LabelId label) const override { return shards_map_.label_spaces.contains(label); } // TODO(kostasrim) Simplify return result std::vector<VertexAccessor> ScanVertices(std::optional<std::string> label) override { // create requests - std::vector<ShardRequestState<msgs::ScanVerticesRequest>> requests_to_be_sent = RequestsForScanVertices(label); + auto requests_to_be_sent = RequestsForScanVertices(label); + spdlog::trace("created {} ScanVertices requests", requests_to_be_sent.size()); // begin all requests in parallel @@ -360,6 +373,7 @@ class RequestRouter : public RequestRouterInterface { } std::vector<msgs::GetPropertiesResultRow> GetProperties(msgs::GetPropertiesRequest requests) override { + requests.transaction_id = transaction_id_; // create requests std::vector<ShardRequestState<msgs::GetPropertiesRequest>> requests_to_be_sent = RequestsForGetProperties(std::move(requests)); @@ -708,6 +722,23 @@ class RequestRouter : public RequestRouterInterface { edge_types_.StoreMapping(std::move(id_to_name)); } + std::optional<std::pair<uint64_t, uint64_t>> AllocateInitialEdgeIds(io::Address coordinator_address) override { + coordinator::CoordinatorWriteRequests requests{coordinator::AllocateEdgeIdBatchRequest{.batch_size = 1000000}}; + + io::rsm::WriteRequest<coordinator::CoordinatorWriteRequests> ww; + ww.operation = requests; + auto resp = + io_.template Request<io::rsm::WriteRequest<coordinator::CoordinatorWriteRequests>, + io::rsm::WriteResponse<coordinator::CoordinatorWriteResponses>>(coordinator_address, ww) + .Wait(); + if (resp.HasValue()) { + const auto alloc_edge_id_reps = + std::get<coordinator::AllocateEdgeIdBatchResponse>(resp.GetValue().message.write_return); + return std::make_pair(alloc_edge_id_reps.low, alloc_edge_id_reps.high); + } + return {}; + } + ShardMap shards_map_; storage::v3::NameIdMapper properties_; storage::v3::NameIdMapper edge_types_; @@ -719,4 +750,66 @@ class RequestRouter : public RequestRouterInterface { io::Notifier notifier_ = {}; // TODO(kostasrim) Add batch prefetching }; + +class RequestRouterFactory { + public: + RequestRouterFactory() = default; + RequestRouterFactory(const RequestRouterFactory &) = delete; + RequestRouterFactory &operator=(const RequestRouterFactory &) = delete; + RequestRouterFactory(RequestRouterFactory &&) = delete; + RequestRouterFactory &operator=(RequestRouterFactory &&) = delete; + + virtual ~RequestRouterFactory() = default; + + virtual std::unique_ptr<RequestRouterInterface> CreateRequestRouter( + const coordinator::Address &coordinator_address) const = 0; +}; + +class LocalRequestRouterFactory : public RequestRouterFactory { + using LocalTransportIo = io::Io<io::local_transport::LocalTransport>; + LocalTransportIo &io_; + + public: + explicit LocalRequestRouterFactory(LocalTransportIo &io) : io_(io) {} + + std::unique_ptr<RequestRouterInterface> CreateRequestRouter( + const coordinator::Address &coordinator_address) const override { + using TransportType = io::local_transport::LocalTransport; + + auto query_io = io_.ForkLocal(boost::uuids::uuid{boost::uuids::random_generator()()}); + auto local_transport_io = io_.ForkLocal(boost::uuids::uuid{boost::uuids::random_generator()()}); + + return std::make_unique<RequestRouter<TransportType>>( + coordinator::CoordinatorClient<TransportType>(query_io, coordinator_address, {coordinator_address}), + std::move(local_transport_io)); + } +}; + +class SimulatedRequestRouterFactory : public RequestRouterFactory { + io::simulator::Simulator *simulator_; + + public: + explicit SimulatedRequestRouterFactory(io::simulator::Simulator &simulator) : simulator_(&simulator) {} + + std::unique_ptr<RequestRouterInterface> CreateRequestRouter( + const coordinator::Address &coordinator_address) const override { + using TransportType = io::simulator::SimulatorTransport; + auto actual_transport_handle = simulator_->GetSimulatorHandle(); + + boost::uuids::uuid random_uuid; + io::Address unique_local_addr_query; + + // The simulated RR should not introduce stochastic behavior. + random_uuid = boost::uuids::uuid{3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + unique_local_addr_query = {.unique_id = boost::uuids::uuid{4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}; + + auto io = simulator_->Register(unique_local_addr_query); + auto query_io = io.ForkLocal(random_uuid); + + return std::make_unique<RequestRouter<TransportType>>( + coordinator::CoordinatorClient<TransportType>(query_io, coordinator_address, {coordinator_address}), + std::move(io)); + } +}; + } // namespace memgraph::query::v2 diff --git a/src/storage/v3/shard_rsm.cpp b/src/storage/v3/shard_rsm.cpp index 67a58e5cc..881796a70 100644 --- a/src/storage/v3/shard_rsm.cpp +++ b/src/storage/v3/shard_rsm.cpp @@ -544,6 +544,7 @@ msgs::ReadResponses ShardRsm::HandleRead(msgs::GetPropertiesRequest &&req) { } const auto *schema = shard_->GetSchema(shard_->PrimaryLabel()); MG_ASSERT(schema); + return CollectAllPropertiesFromAccessor(v_acc, view, *schema); } diff --git a/src/utils/event_counter.cpp b/src/utils/event_counter.cpp index 634b6eae2..2928027ee 100644 --- a/src/utils/event_counter.cpp +++ b/src/utils/event_counter.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -25,6 +25,7 @@ M(ScanAllByLabelPropertyValueOperator, "Number of times ScanAllByLabelPropertyValue operator was used.") \ M(ScanAllByLabelPropertyOperator, "Number of times ScanAllByLabelProperty operator was used.") \ M(ScanAllByIdOperator, "Number of times ScanAllById operator was used.") \ + M(ScanByPrimaryKeyOperator, "Number of times ScanByPrimaryKey operator was used.") \ M(ExpandOperator, "Number of times Expand operator was used.") \ M(ExpandVariableOperator, "Number of times ExpandVariable operator was used.") \ M(ConstructNamedPathOperator, "Number of times ConstructNamedPath operator was used.") \ diff --git a/tests/mgbench/client.cpp b/tests/mgbench/client.cpp index e4b63d477..000c199fa 100644 --- a/tests/mgbench/client.cpp +++ b/tests/mgbench/client.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source diff --git a/tests/simulation/CMakeLists.txt b/tests/simulation/CMakeLists.txt index cd5fc0a4a..f3d4870f0 100644 --- a/tests/simulation/CMakeLists.txt +++ b/tests/simulation/CMakeLists.txt @@ -17,7 +17,7 @@ function(add_simulation_test test_cpp) # requires unique logical target names set_target_properties(${target_name} PROPERTIES OUTPUT_NAME ${exec_name}) - target_link_libraries(${target_name} mg-storage-v3 mg-communication mg-utils mg-io mg-io-simulator mg-coordinator mg-query-v2) + target_link_libraries(${target_name} mg-communication mg-utils mg-io mg-io-simulator mg-coordinator mg-query-v2 mg-storage-v3) target_link_libraries(${target_name} Boost::headers) target_link_libraries(${target_name} gtest gtest_main gmock rapidcheck rapidcheck_gtest) @@ -32,4 +32,5 @@ add_simulation_test(trial_query_storage/query_storage_test.cpp) add_simulation_test(sharded_map.cpp) add_simulation_test(shard_rsm.cpp) add_simulation_test(cluster_property_test.cpp) +add_simulation_test(cluster_property_test_cypher_queries.cpp) add_simulation_test(request_router.cpp) diff --git a/tests/simulation/cluster_property_test_cypher_queries.cpp b/tests/simulation/cluster_property_test_cypher_queries.cpp new file mode 100644 index 000000000..e35edc033 --- /dev/null +++ b/tests/simulation/cluster_property_test_cypher_queries.cpp @@ -0,0 +1,64 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +// This test serves as an example of a property-based model test. +// It generates a cluster configuration and a set of operations to +// apply against both the real system and a greatly simplified model. + +#include <chrono> + +#include <gtest/gtest.h> +#include <rapidcheck.h> +#include <rapidcheck/gtest.h> +#include <spdlog/cfg/env.h> + +#include "generated_operations.hpp" +#include "io/simulator/simulator_config.hpp" +#include "io/time.hpp" +#include "storage/v3/shard_manager.hpp" +#include "test_cluster.hpp" + +namespace memgraph::tests::simulation { + +using io::Duration; +using io::Time; +using io::simulator::SimulatorConfig; +using storage::v3::kMaximumCronInterval; + +RC_GTEST_PROP(RandomClusterConfig, HappyPath, (ClusterConfig cluster_config, NonEmptyOpVec ops, uint64_t rng_seed)) { + spdlog::cfg::load_env_levels(); + + SimulatorConfig sim_config{ + .drop_percent = 0, + .perform_timeouts = false, + .scramble_messages = true, + .rng_seed = rng_seed, + .start_time = Time::min(), + .abort_time = Time::max(), + }; + + std::vector<std::string> queries = {"CREATE (n:test_label{property_1: 0, property_2: 0});", "MATCH (n) RETURN n;"}; + + auto [sim_stats_1, latency_stats_1] = RunClusterSimulationWithQueries(sim_config, cluster_config, queries); + auto [sim_stats_2, latency_stats_2] = RunClusterSimulationWithQueries(sim_config, cluster_config, queries); + + if (latency_stats_1 != latency_stats_2) { + spdlog::error("simulator stats diverged across runs"); + spdlog::error("run 1 simulator stats: {}", sim_stats_1); + spdlog::error("run 2 simulator stats: {}", sim_stats_2); + spdlog::error("run 1 latency:\n{}", latency_stats_1.SummaryTable()); + spdlog::error("run 2 latency:\n{}", latency_stats_2.SummaryTable()); + RC_ASSERT(latency_stats_1 == latency_stats_2); + RC_ASSERT(sim_stats_1 == sim_stats_2); + } +} + +} // namespace memgraph::tests::simulation diff --git a/tests/simulation/request_router.cpp b/tests/simulation/request_router.cpp index 4248e7876..037674b66 100644 --- a/tests/simulation/request_router.cpp +++ b/tests/simulation/request_router.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source diff --git a/tests/simulation/simulation_interpreter.hpp b/tests/simulation/simulation_interpreter.hpp new file mode 100644 index 000000000..e83980787 --- /dev/null +++ b/tests/simulation/simulation_interpreter.hpp @@ -0,0 +1,93 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "io/simulator/simulator_handle.hpp" +#include "machine_manager/machine_config.hpp" +#include "machine_manager/machine_manager.hpp" +#include "query/v2/config.hpp" +#include "query/v2/discard_value_stream.hpp" +#include "query/v2/frontend/ast/ast.hpp" +#include "query/v2/interpreter.hpp" +#include "query/v2/request_router.hpp" + +#include <string> +#include <vector> + +// TODO(gvolfing) +// -How to set up the entire raft cluster with the QE. Also provide abrstraction for that. +// -Pass an argument to the setup to determine, how many times the retry of a query should happen. + +namespace memgraph::io::simulator { + +class SimulatedInterpreter { + using ResultStream = query::v2::DiscardValueResultStream; + + public: + explicit SimulatedInterpreter(std::unique_ptr<query::v2::InterpreterContext> interpreter_context) + : interpreter_context_(std::move(interpreter_context)) { + interpreter_ = std::make_unique<memgraph::query::v2::Interpreter>(interpreter_context_.get()); + } + + SimulatedInterpreter(const SimulatedInterpreter &) = delete; + SimulatedInterpreter &operator=(const SimulatedInterpreter &) = delete; + SimulatedInterpreter(SimulatedInterpreter &&) = delete; + SimulatedInterpreter &operator=(SimulatedInterpreter &&) = delete; + ~SimulatedInterpreter() = default; + + void InstallSimulatorTicker(Simulator &simulator) { + interpreter_->InstallSimulatorTicker(simulator.GetSimulatorTickClosure()); + } + + std::vector<ResultStream> RunQueries(const std::vector<std::string> &queries) { + std::vector<ResultStream> results; + results.reserve(queries.size()); + + for (const auto &query : queries) { + results.emplace_back(RunQuery(query)); + } + return results; + } + + private: + ResultStream RunQuery(const std::string &query) { + ResultStream stream; + + std::map<std::string, memgraph::storage::v3::PropertyValue> params; + const std::string *username = nullptr; + + interpreter_->Prepare(query, params, username); + interpreter_->PullAll(&stream); + + return stream; + } + + std::unique_ptr<query::v2::InterpreterContext> interpreter_context_; + std::unique_ptr<query::v2::Interpreter> interpreter_; +}; + +SimulatedInterpreter SetUpInterpreter(Address coordinator_address, Simulator &simulator) { + auto rr_factory = std::make_unique<memgraph::query::v2::SimulatedRequestRouterFactory>(simulator); + + auto interpreter_context = std::make_unique<memgraph::query::v2::InterpreterContext>( + nullptr, + memgraph::query::v2::InterpreterConfig{.query = {.allow_load_csv = true}, + .execution_timeout_sec = 600, + .replication_replica_check_frequency = std::chrono::seconds(1), + .default_kafka_bootstrap_servers = "", + .default_pulsar_service_url = "", + .stream_transaction_conflict_retries = 30, + .stream_transaction_retry_interval = std::chrono::milliseconds(500)}, + std::filesystem::path("mg_data"), std::move(rr_factory), coordinator_address); + + return SimulatedInterpreter(std::move(interpreter_context)); +} + +} // namespace memgraph::io::simulator diff --git a/tests/simulation/test_cluster.hpp b/tests/simulation/test_cluster.hpp index 2e8bdf92f..f10e88e61 100644 --- a/tests/simulation/test_cluster.hpp +++ b/tests/simulation/test_cluster.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -36,6 +36,8 @@ #include "utils/print_helpers.hpp" #include "utils/variant_helpers.hpp" +#include "simulation_interpreter.hpp" + namespace memgraph::tests::simulation { using coordinator::Coordinator; @@ -279,4 +281,65 @@ std::pair<SimulatorStats, LatencyHistogramSummaries> RunClusterSimulation(const return std::make_pair(stats, histo); } +std::pair<SimulatorStats, LatencyHistogramSummaries> RunClusterSimulationWithQueries( + const SimulatorConfig &sim_config, const ClusterConfig &cluster_config, const std::vector<std::string> &queries) { + spdlog::info("========================== NEW SIMULATION =========================="); + + auto simulator = Simulator(sim_config); + + auto machine_1_addr = Address::TestAddress(1); + auto cli_addr = Address::TestAddress(2); + auto cli_addr_2 = Address::TestAddress(3); + + Io<SimulatorTransport> cli_io = simulator.Register(cli_addr); + Io<SimulatorTransport> cli_io_2 = simulator.Register(cli_addr_2); + + auto coordinator_addresses = std::vector{ + machine_1_addr, + }; + + ShardMap initialization_sm = TestShardMap(cluster_config.shards - 1, cluster_config.replication_factor); + + auto mm_1 = MkMm(simulator, coordinator_addresses, machine_1_addr, initialization_sm); + Address coordinator_address = mm_1.CoordinatorAddress(); + + auto mm_thread_1 = std::jthread(RunMachine, std::move(mm_1)); + simulator.IncrementServerCountAndWaitForQuiescentState(machine_1_addr); + + auto detach_on_error = DetachIfDropped{.handle = mm_thread_1}; + + // TODO(tyler) clarify addresses of coordinator etc... as it's a mess + + CoordinatorClient<SimulatorTransport> coordinator_client(cli_io, coordinator_address, {coordinator_address}); + WaitForShardsToInitialize(coordinator_client); + + auto simulated_interpreter = io::simulator::SetUpInterpreter(coordinator_address, simulator); + simulated_interpreter.InstallSimulatorTicker(simulator); + + auto query_results = simulated_interpreter.RunQueries(queries); + + // We have now completed our workload without failing any assertions, so we can + // disable detaching the worker thread, which will cause the mm_thread_1 jthread + // to be joined when this function returns. + detach_on_error.detach = false; + + simulator.ShutDown(); + + mm_thread_1.join(); + + SimulatorStats stats = simulator.Stats(); + + spdlog::info("total messages: {}", stats.total_messages); + spdlog::info("dropped messages: {}", stats.dropped_messages); + spdlog::info("timed out requests: {}", stats.timed_out_requests); + spdlog::info("total requests: {}", stats.total_requests); + spdlog::info("total responses: {}", stats.total_responses); + spdlog::info("simulator ticks: {}", stats.simulator_ticks); + + auto histo = cli_io_2.ResponseLatencies(); + + spdlog::info("========================== SUCCESS :) =========================="); + return std::make_pair(stats, histo); +} + } // namespace memgraph::tests::simulation diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index e7efb7473..ae747385b 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -93,6 +93,9 @@ target_link_libraries(${test_prefix}query_expression_evaluator mg-query) add_unit_test(query_plan.cpp) target_link_libraries(${test_prefix}query_plan mg-query) +add_unit_test(query_v2_plan.cpp) +target_link_libraries(${test_prefix}query_v2_plan mg-query-v2) + add_unit_test(query_plan_accumulate_aggregate.cpp) target_link_libraries(${test_prefix}query_plan_accumulate_aggregate mg-query) diff --git a/tests/unit/high_density_shard_create_scan.cpp b/tests/unit/high_density_shard_create_scan.cpp index 9fabf6ccc..4a90b98b2 100644 --- a/tests/unit/high_density_shard_create_scan.cpp +++ b/tests/unit/high_density_shard_create_scan.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source diff --git a/tests/unit/machine_manager.cpp b/tests/unit/machine_manager.cpp index 74b7d3863..6cc6a3ff1 100644 --- a/tests/unit/machine_manager.cpp +++ b/tests/unit/machine_manager.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source diff --git a/tests/unit/mock_helpers.hpp b/tests/unit/mock_helpers.hpp index 302b2ff55..15f264cac 100644 --- a/tests/unit/mock_helpers.hpp +++ b/tests/unit/mock_helpers.hpp @@ -42,6 +42,9 @@ class MockedRequestRouter : public RequestRouterInterface { MOCK_METHOD(std::optional<storage::v3::LabelId>, MaybeNameToLabel, (const std::string &), (const)); MOCK_METHOD(bool, IsPrimaryLabel, (storage::v3::LabelId), (const)); MOCK_METHOD(bool, IsPrimaryKey, (storage::v3::LabelId, storage::v3::PropertyId), (const)); + MOCK_METHOD((std::optional<std::pair<uint64_t, uint64_t>>), AllocateInitialEdgeIds, (io::Address)); + MOCK_METHOD(void, InstallSimulatorTicker, (std::function<bool()>)); + MOCK_METHOD(const std::vector<coordinator::SchemaProperty> &, GetSchemaForLabel, (storage::v3::LabelId), (const)); }; class MockedLogicalOperator : public plan::LogicalOperator { diff --git a/tests/unit/query_common.hpp b/tests/unit/query_common.hpp index 903cacf12..784df5e21 100644 --- a/tests/unit/query_common.hpp +++ b/tests/unit/query_common.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -501,6 +501,8 @@ auto GetForeach(AstStorage &storage, NamedExpression *named_expr, const std::vec storage.Create<memgraph::query::MapLiteral>( \ std::unordered_map<memgraph::query::PropertyIx, memgraph::query::Expression *>{__VA_ARGS__}) #define PROPERTY_PAIR(property_name) std::make_pair(property_name, dba.NameToProperty(property_name)) +#define PRIMARY_PROPERTY_PAIR(property_name) std::make_pair(property_name, dba.NameToPrimaryProperty(property_name)) +#define SECONDARY_PROPERTY_PAIR(property_name) std::make_pair(property_name, dba.NameToSecondaryProperty(property_name)) #define PROPERTY_LOOKUP(...) memgraph::query::test_common::GetPropertyLookup(storage, dba, __VA_ARGS__) #define PARAMETER_LOOKUP(token_position) storage.Create<memgraph::query::ParameterLookup>((token_position)) #define NEXPR(name, expr) storage.Create<memgraph::query::NamedExpression>((name), (expr)) diff --git a/tests/unit/query_plan_checker.hpp b/tests/unit/query_plan_checker.hpp index 335b6ab2b..7907f167f 100644 --- a/tests/unit/query_plan_checker.hpp +++ b/tests/unit/query_plan_checker.hpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -17,6 +17,7 @@ #include "query/plan/operator.hpp" #include "query/plan/planner.hpp" #include "query/plan/preprocess.hpp" +#include "query/v2/plan/operator.hpp" namespace memgraph::query::plan { @@ -90,7 +91,7 @@ class PlanChecker : public virtual HierarchicalLogicalOperatorVisitor { } PRE_VISIT(Unwind); PRE_VISIT(Distinct); - + bool PreVisit(Foreach &op) override { CheckOp(op); return false; @@ -336,6 +337,35 @@ class ExpectScanAllByLabelProperty : public OpChecker<ScanAllByLabelProperty> { memgraph::storage::PropertyId property_; }; +class ExpectScanByPrimaryKey : public OpChecker<v2::plan::ScanByPrimaryKey> { + public: + ExpectScanByPrimaryKey(memgraph::storage::v3::LabelId label, const std::vector<Expression *> &properties) + : label_(label), properties_(properties) {} + + void ExpectOp(v2::plan::ScanByPrimaryKey &scan_all, const SymbolTable &) override { + EXPECT_EQ(scan_all.label_, label_); + + bool primary_property_match = true; + for (const auto &expected_prop : properties_) { + bool has_match = false; + for (const auto &prop : scan_all.primary_key_) { + if (typeid(prop).hash_code() == typeid(expected_prop).hash_code()) { + has_match = true; + } + } + if (!has_match) { + primary_property_match = false; + } + } + + EXPECT_TRUE(primary_property_match); + } + + private: + memgraph::storage::v3::LabelId label_; + std::vector<Expression *> properties_; +}; + class ExpectCartesian : public OpChecker<Cartesian> { public: ExpectCartesian(const std::list<std::unique_ptr<BaseOpChecker>> &left, diff --git a/tests/unit/query_plan_checker_v2.hpp b/tests/unit/query_plan_checker_v2.hpp new file mode 100644 index 000000000..014f1c2bb --- /dev/null +++ b/tests/unit/query_plan_checker_v2.hpp @@ -0,0 +1,452 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +#include "query/frontend/semantic/symbol_generator.hpp" +#include "query/frontend/semantic/symbol_table.hpp" +#include "query/v2/plan/operator.hpp" +#include "query/v2/plan/planner.hpp" +#include "query/v2/plan/preprocess.hpp" +#include "utils/exceptions.hpp" + +namespace memgraph::query::v2::plan { + +class BaseOpChecker { + public: + virtual ~BaseOpChecker() {} + + virtual void CheckOp(LogicalOperator &, const SymbolTable &) = 0; +}; + +class PlanChecker : public virtual HierarchicalLogicalOperatorVisitor { + public: + using HierarchicalLogicalOperatorVisitor::PostVisit; + using HierarchicalLogicalOperatorVisitor::PreVisit; + using HierarchicalLogicalOperatorVisitor::Visit; + + PlanChecker(const std::list<std::unique_ptr<BaseOpChecker>> &checkers, const SymbolTable &symbol_table) + : symbol_table_(symbol_table) { + for (const auto &checker : checkers) checkers_.emplace_back(checker.get()); + } + + PlanChecker(const std::list<BaseOpChecker *> &checkers, const SymbolTable &symbol_table) + : checkers_(checkers), symbol_table_(symbol_table) {} + +#define PRE_VISIT(TOp) \ + bool PreVisit(TOp &op) override { \ + CheckOp(op); \ + return true; \ + } + +#define VISIT(TOp) \ + bool Visit(TOp &op) override { \ + CheckOp(op); \ + return true; \ + } + + PRE_VISIT(CreateNode); + PRE_VISIT(CreateExpand); + PRE_VISIT(Delete); + PRE_VISIT(ScanAll); + PRE_VISIT(ScanAllByLabel); + PRE_VISIT(ScanAllByLabelPropertyValue); + PRE_VISIT(ScanAllByLabelPropertyRange); + PRE_VISIT(ScanAllByLabelProperty); + PRE_VISIT(ScanByPrimaryKey); + PRE_VISIT(Expand); + PRE_VISIT(ExpandVariable); + PRE_VISIT(Filter); + PRE_VISIT(ConstructNamedPath); + PRE_VISIT(Produce); + PRE_VISIT(SetProperty); + PRE_VISIT(SetProperties); + PRE_VISIT(SetLabels); + PRE_VISIT(RemoveProperty); + PRE_VISIT(RemoveLabels); + PRE_VISIT(EdgeUniquenessFilter); + PRE_VISIT(Accumulate); + PRE_VISIT(Aggregate); + PRE_VISIT(Skip); + PRE_VISIT(Limit); + PRE_VISIT(OrderBy); + bool PreVisit(Merge &op) override { + CheckOp(op); + op.input()->Accept(*this); + return false; + } + bool PreVisit(Optional &op) override { + CheckOp(op); + op.input()->Accept(*this); + return false; + } + PRE_VISIT(Unwind); + PRE_VISIT(Distinct); + + bool PreVisit(Foreach &op) override { + CheckOp(op); + return false; + } + + bool Visit(Once &) override { + // Ignore checking Once, it is implicitly at the end. + return true; + } + + bool PreVisit(Cartesian &op) override { + CheckOp(op); + return false; + } + + PRE_VISIT(CallProcedure); + +#undef PRE_VISIT +#undef VISIT + + void CheckOp(LogicalOperator &op) { + ASSERT_FALSE(checkers_.empty()); + checkers_.back()->CheckOp(op, symbol_table_); + checkers_.pop_back(); + } + + std::list<BaseOpChecker *> checkers_; + const SymbolTable &symbol_table_; +}; + +template <class TOp> +class OpChecker : public BaseOpChecker { + public: + void CheckOp(LogicalOperator &op, const SymbolTable &symbol_table) override { + auto *expected_op = dynamic_cast<TOp *>(&op); + ASSERT_TRUE(expected_op) << "op is '" << op.GetTypeInfo().name << "' expected '" << TOp::kType.name << "'!"; + ExpectOp(*expected_op, symbol_table); + } + + virtual void ExpectOp(TOp &, const SymbolTable &) {} +}; + +using ExpectCreateNode = OpChecker<CreateNode>; +using ExpectCreateExpand = OpChecker<CreateExpand>; +using ExpectDelete = OpChecker<Delete>; +using ExpectScanAll = OpChecker<ScanAll>; +using ExpectScanAllByLabel = OpChecker<ScanAllByLabel>; +using ExpectExpand = OpChecker<Expand>; +using ExpectFilter = OpChecker<Filter>; +using ExpectConstructNamedPath = OpChecker<ConstructNamedPath>; +using ExpectProduce = OpChecker<Produce>; +using ExpectSetProperty = OpChecker<SetProperty>; +using ExpectSetProperties = OpChecker<SetProperties>; +using ExpectSetLabels = OpChecker<SetLabels>; +using ExpectRemoveProperty = OpChecker<RemoveProperty>; +using ExpectRemoveLabels = OpChecker<RemoveLabels>; +using ExpectEdgeUniquenessFilter = OpChecker<EdgeUniquenessFilter>; +using ExpectSkip = OpChecker<Skip>; +using ExpectLimit = OpChecker<Limit>; +using ExpectOrderBy = OpChecker<OrderBy>; +using ExpectUnwind = OpChecker<Unwind>; +using ExpectDistinct = OpChecker<Distinct>; + +class ExpectScanAllByLabelPropertyValue : public OpChecker<ScanAllByLabelPropertyValue> { + public: + ExpectScanAllByLabelPropertyValue(memgraph::storage::v3::LabelId label, + const std::pair<std::string, memgraph::storage::v3::PropertyId> &prop_pair, + memgraph::query::v2::Expression *expression) + : label_(label), property_(prop_pair.second), expression_(expression) {} + + void ExpectOp(ScanAllByLabelPropertyValue &scan_all, const SymbolTable &) override { + EXPECT_EQ(scan_all.label_, label_); + EXPECT_EQ(scan_all.property_, property_); + // TODO: Proper expression equality + EXPECT_EQ(typeid(scan_all.expression_).hash_code(), typeid(expression_).hash_code()); + } + + private: + memgraph::storage::v3::LabelId label_; + memgraph::storage::v3::PropertyId property_; + memgraph::query::v2::Expression *expression_; +}; + +class ExpectScanByPrimaryKey : public OpChecker<v2::plan::ScanByPrimaryKey> { + public: + ExpectScanByPrimaryKey(memgraph::storage::v3::LabelId label, const std::vector<Expression *> &properties) + : label_(label), properties_(properties) {} + + void ExpectOp(v2::plan::ScanByPrimaryKey &scan_all, const SymbolTable &) override { + EXPECT_EQ(scan_all.label_, label_); + + bool primary_property_match = true; + for (const auto &expected_prop : properties_) { + bool has_match = false; + for (const auto &prop : scan_all.primary_key_) { + if (typeid(prop).hash_code() == typeid(expected_prop).hash_code()) { + has_match = true; + } + } + if (!has_match) { + primary_property_match = false; + } + } + + EXPECT_TRUE(primary_property_match); + } + + private: + memgraph::storage::v3::LabelId label_; + std::vector<Expression *> properties_; +}; + +class ExpectCartesian : public OpChecker<Cartesian> { + public: + ExpectCartesian(const std::list<std::unique_ptr<BaseOpChecker>> &left, + const std::list<std::unique_ptr<BaseOpChecker>> &right) + : left_(left), right_(right) {} + + void ExpectOp(Cartesian &op, const SymbolTable &symbol_table) override { + ASSERT_TRUE(op.left_op_); + PlanChecker left_checker(left_, symbol_table); + op.left_op_->Accept(left_checker); + ASSERT_TRUE(op.right_op_); + PlanChecker right_checker(right_, symbol_table); + op.right_op_->Accept(right_checker); + } + + private: + const std::list<std::unique_ptr<BaseOpChecker>> &left_; + const std::list<std::unique_ptr<BaseOpChecker>> &right_; +}; + +class ExpectCallProcedure : public OpChecker<CallProcedure> { + public: + ExpectCallProcedure(const std::string &name, const std::vector<memgraph::query::Expression *> &args, + const std::vector<std::string> &fields, const std::vector<Symbol> &result_syms) + : name_(name), args_(args), fields_(fields), result_syms_(result_syms) {} + + void ExpectOp(CallProcedure &op, const SymbolTable &symbol_table) override { + EXPECT_EQ(op.procedure_name_, name_); + EXPECT_EQ(op.arguments_.size(), args_.size()); + for (size_t i = 0; i < args_.size(); ++i) { + const auto *op_arg = op.arguments_[i]; + const auto *expected_arg = args_[i]; + // TODO: Proper expression equality + EXPECT_EQ(op_arg->GetTypeInfo(), expected_arg->GetTypeInfo()); + } + EXPECT_EQ(op.result_fields_, fields_); + EXPECT_EQ(op.result_symbols_, result_syms_); + } + + private: + std::string name_; + std::vector<memgraph::query::Expression *> args_; + std::vector<std::string> fields_; + std::vector<Symbol> result_syms_; +}; + +template <class T> +std::list<std::unique_ptr<BaseOpChecker>> MakeCheckers(T arg) { + std::list<std::unique_ptr<BaseOpChecker>> l; + l.emplace_back(std::make_unique<T>(arg)); + return l; +} + +template <class T, class... Rest> +std::list<std::unique_ptr<BaseOpChecker>> MakeCheckers(T arg, Rest &&...rest) { + auto l = MakeCheckers(std::forward<Rest>(rest)...); + l.emplace_front(std::make_unique<T>(arg)); + return std::move(l); +} + +template <class TPlanner, class TDbAccessor> +TPlanner MakePlanner(TDbAccessor *dba, AstStorage &storage, SymbolTable &symbol_table, CypherQuery *query) { + auto planning_context = MakePlanningContext(&storage, &symbol_table, query, dba); + auto query_parts = CollectQueryParts(symbol_table, storage, query); + auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; + return TPlanner(single_query_parts, planning_context); +} + +class FakeDistributedDbAccessor { + public: + int64_t VerticesCount(memgraph::storage::v3::LabelId label) const { + auto found = label_index_.find(label); + if (found != label_index_.end()) return found->second; + return 0; + } + + int64_t VerticesCount(memgraph::storage::v3::LabelId label, memgraph::storage::v3::PropertyId property) const { + for (auto &index : label_property_index_) { + if (std::get<0>(index) == label && std::get<1>(index) == property) { + return std::get<2>(index); + } + } + return 0; + } + + bool LabelIndexExists(memgraph::storage::v3::LabelId label) const { + throw utils::NotYetImplemented("Label indicies are yet to be implemented."); + } + + bool LabelPropertyIndexExists(memgraph::storage::v3::LabelId label, + memgraph::storage::v3::PropertyId property) const { + for (auto &index : label_property_index_) { + if (std::get<0>(index) == label && std::get<1>(index) == property) { + return true; + } + } + return false; + } + + bool PrimaryLabelExists(storage::v3::LabelId label) { return label_index_.find(label) != label_index_.end(); } + + void SetIndexCount(memgraph::storage::v3::LabelId label, int64_t count) { label_index_[label] = count; } + + void SetIndexCount(memgraph::storage::v3::LabelId label, memgraph::storage::v3::PropertyId property, int64_t count) { + for (auto &index : label_property_index_) { + if (std::get<0>(index) == label && std::get<1>(index) == property) { + std::get<2>(index) = count; + return; + } + } + label_property_index_.emplace_back(label, property, count); + } + + memgraph::storage::v3::LabelId NameToLabel(const std::string &name) { + auto found = primary_labels_.find(name); + if (found != primary_labels_.end()) return found->second; + return primary_labels_.emplace(name, memgraph::storage::v3::LabelId::FromUint(primary_labels_.size())) + .first->second; + } + + memgraph::storage::v3::LabelId Label(const std::string &name) { return NameToLabel(name); } + + memgraph::storage::v3::EdgeTypeId NameToEdgeType(const std::string &name) { + auto found = edge_types_.find(name); + if (found != edge_types_.end()) return found->second; + return edge_types_.emplace(name, memgraph::storage::v3::EdgeTypeId::FromUint(edge_types_.size())).first->second; + } + + memgraph::storage::v3::PropertyId NameToPrimaryProperty(const std::string &name) { + auto found = primary_properties_.find(name); + if (found != primary_properties_.end()) return found->second; + return primary_properties_.emplace(name, memgraph::storage::v3::PropertyId::FromUint(primary_properties_.size())) + .first->second; + } + + memgraph::storage::v3::PropertyId NameToSecondaryProperty(const std::string &name) { + auto found = secondary_properties_.find(name); + if (found != secondary_properties_.end()) return found->second; + return secondary_properties_ + .emplace(name, memgraph::storage::v3::PropertyId::FromUint(secondary_properties_.size())) + .first->second; + } + + memgraph::storage::v3::PropertyId PrimaryProperty(const std::string &name) { return NameToPrimaryProperty(name); } + memgraph::storage::v3::PropertyId SecondaryProperty(const std::string &name) { return NameToSecondaryProperty(name); } + + std::string PrimaryPropertyToName(memgraph::storage::v3::PropertyId property) const { + for (const auto &kv : primary_properties_) { + if (kv.second == property) return kv.first; + } + LOG_FATAL("Unable to find primary property name"); + } + + std::string SecondaryPropertyToName(memgraph::storage::v3::PropertyId property) const { + for (const auto &kv : secondary_properties_) { + if (kv.second == property) return kv.first; + } + LOG_FATAL("Unable to find secondary property name"); + } + + std::string PrimaryPropertyName(memgraph::storage::v3::PropertyId property) const { + return PrimaryPropertyToName(property); + } + std::string SecondaryPropertyName(memgraph::storage::v3::PropertyId property) const { + return SecondaryPropertyToName(property); + } + + memgraph::storage::v3::PropertyId NameToProperty(const std::string &name) { + auto find_in_prim_properties = primary_properties_.find(name); + if (find_in_prim_properties != primary_properties_.end()) { + return find_in_prim_properties->second; + } + auto find_in_secondary_properties = secondary_properties_.find(name); + if (find_in_secondary_properties != secondary_properties_.end()) { + return find_in_secondary_properties->second; + } + + LOG_FATAL("The property does not exist as a primary or a secondary property."); + return memgraph::storage::v3::PropertyId::FromUint(0); + } + + std::vector<memgraph::storage::v3::SchemaProperty> GetSchemaForLabel(storage::v3::LabelId label) { + auto schema_properties = schemas_.at(label); + std::vector<memgraph::storage::v3::SchemaProperty> ret; + std::transform(schema_properties.begin(), schema_properties.end(), std::back_inserter(ret), [](const auto &prop) { + memgraph::storage::v3::SchemaProperty schema_prop = { + .property_id = prop, + // This should not be hardcoded, but for testing purposes it will suffice. + .type = memgraph::common::SchemaType::INT}; + + return schema_prop; + }); + return ret; + } + + std::vector<std::pair<query::v2::Expression *, query::v2::plan::FilterInfo>> ExtractPrimaryKey( + storage::v3::LabelId label, std::vector<query::v2::plan::FilterInfo> property_filters) { + MG_ASSERT(schemas_.contains(label), + "You did not specify the Schema for this label! Use FakeDistributedDbAccessor::CreateSchema(...)."); + + std::vector<std::pair<query::v2::Expression *, query::v2::plan::FilterInfo>> pk; + const auto schema = GetSchemaPropertiesForLabel(label); + + std::vector<storage::v3::PropertyId> schema_properties; + schema_properties.reserve(schema.size()); + + std::transform(schema.begin(), schema.end(), std::back_inserter(schema_properties), + [](const auto &schema_elem) { return schema_elem; }); + + for (const auto &property_filter : property_filters) { + const auto &property_id = NameToProperty(property_filter.property_filter->property_.name); + if (std::find(schema_properties.begin(), schema_properties.end(), property_id) != schema_properties.end()) { + pk.emplace_back(std::make_pair(property_filter.expression, property_filter)); + } + } + + return pk.size() == schema_properties.size() + ? pk + : std::vector<std::pair<query::v2::Expression *, query::v2::plan::FilterInfo>>{}; + } + + std::vector<memgraph::storage::v3::PropertyId> GetSchemaPropertiesForLabel(storage::v3::LabelId label) { + return schemas_.at(label); + } + + void CreateSchema(const memgraph::storage::v3::LabelId primary_label, + const std::vector<memgraph::storage::v3::PropertyId> &schemas_types) { + MG_ASSERT(!schemas_.contains(primary_label), "You already created the schema for this label!"); + schemas_.emplace(primary_label, schemas_types); + } + + private: + std::unordered_map<std::string, memgraph::storage::v3::LabelId> primary_labels_; + std::unordered_map<std::string, memgraph::storage::v3::LabelId> secondary_labels_; + std::unordered_map<std::string, memgraph::storage::v3::EdgeTypeId> edge_types_; + std::unordered_map<std::string, memgraph::storage::v3::PropertyId> primary_properties_; + std::unordered_map<std::string, memgraph::storage::v3::PropertyId> secondary_properties_; + + std::unordered_map<memgraph::storage::v3::LabelId, int64_t> label_index_; + std::vector<std::tuple<memgraph::storage::v3::LabelId, memgraph::storage::v3::PropertyId, int64_t>> + label_property_index_; + + std::unordered_map<memgraph::storage::v3::LabelId, std::vector<memgraph::storage::v3::PropertyId>> schemas_; +}; + +} // namespace memgraph::query::v2::plan diff --git a/tests/unit/query_v2_common.hpp b/tests/unit/query_v2_common.hpp new file mode 100644 index 000000000..1b9d6807c --- /dev/null +++ b/tests/unit/query_v2_common.hpp @@ -0,0 +1,603 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +/// @file +/// This file provides macros for easier construction of openCypher query AST. +/// The usage of macros is very similar to how one would write openCypher. For +/// example: +/// +/// AstStorage storage; // Macros rely on storage being in scope. +/// // PROPERTY_LOOKUP and PROPERTY_PAIR macros +/// // rely on a DbAccessor *reference* named dba. +/// database::GraphDb db; +/// auto dba_ptr = db.Access(); +/// auto &dba = *dba_ptr; +/// +/// QUERY(MATCH(PATTERN(NODE("n"), EDGE("e"), NODE("m"))), +/// WHERE(LESS(PROPERTY_LOOKUP("e", edge_prop), LITERAL(3))), +/// RETURN(SUM(PROPERTY_LOOKUP("m", prop)), AS("sum"), +/// ORDER_BY(IDENT("sum")), +/// SKIP(ADD(LITERAL(1), LITERAL(2))))); +/// +/// Each of the macros is accompanied by a function. The functions use overload +/// resolution and template magic to provide a type safe way of constructing +/// queries. Although the functions can be used by themselves, it is more +/// convenient to use the macros. + +#pragma once + +#include <map> +#include <sstream> +#include <string> +#include <unordered_map> +#include <utility> +#include <vector> + +#include "query/frontend/ast/pretty_print.hpp" // not sure if that is ok... +#include "query/v2/frontend/ast/ast.hpp" +#include "storage/v3/id_types.hpp" +#include "utils/string.hpp" + +#include "query/v2/frontend/ast/ast.hpp" + +namespace memgraph::query::v2 { + +namespace test_common { + +auto ToIntList(const TypedValue &t) { + std::vector<int64_t> list; + for (auto x : t.ValueList()) { + list.push_back(x.ValueInt()); + } + return list; +}; + +auto ToIntMap(const TypedValue &t) { + std::map<std::string, int64_t> map; + for (const auto &kv : t.ValueMap()) map.emplace(kv.first, kv.second.ValueInt()); + return map; +}; + +std::string ToString(Expression *expr) { + std::ostringstream ss; + // PrintExpression(expr, &ss); + return ss.str(); +} + +std::string ToString(NamedExpression *expr) { + std::ostringstream ss; + // PrintExpression(expr, &ss); + return ss.str(); +} + +// Custom types for ORDER BY, SKIP, LIMIT, ON MATCH and ON CREATE expressions, +// so that they can be used to resolve function calls. +struct OrderBy { + std::vector<SortItem> expressions; +}; + +struct Skip { + Expression *expression = nullptr; +}; + +struct Limit { + Expression *expression = nullptr; +}; + +struct OnMatch { + std::vector<Clause *> set; +}; +struct OnCreate { + std::vector<Clause *> set; +}; + +// Helper functions for filling the OrderBy with expressions. +auto FillOrderBy(OrderBy &order_by, Expression *expression, Ordering ordering = Ordering::ASC) { + order_by.expressions.push_back({ordering, expression}); +} +template <class... T> +auto FillOrderBy(OrderBy &order_by, Expression *expression, Ordering ordering, T... rest) { + FillOrderBy(order_by, expression, ordering); + FillOrderBy(order_by, rest...); +} +template <class... T> +auto FillOrderBy(OrderBy &order_by, Expression *expression, T... rest) { + FillOrderBy(order_by, expression); + FillOrderBy(order_by, rest...); +} + +/// Create OrderBy expressions. +/// +/// The supported combination of arguments is: (Expression, [Ordering])+ +/// Since the Ordering is optional, by default it is ascending. +template <class... T> +auto GetOrderBy(T... exprs) { + OrderBy order_by; + FillOrderBy(order_by, exprs...); + return order_by; +} + +/// Create PropertyLookup with given name and property. +/// +/// Name is used to create the Identifier which is used for property lookup. +template <class TDbAccessor> +auto GetPropertyLookup(AstStorage &storage, TDbAccessor &dba, const std::string &name, + memgraph::storage::v3::PropertyId property) { + return storage.Create<PropertyLookup>(storage.Create<Identifier>(name), + storage.GetPropertyIx(dba.PropertyToName(property))); +} + +template <class TDbAccessor> +auto GetPropertyLookup(AstStorage &storage, TDbAccessor &dba, Expression *expr, + memgraph::storage::v3::PropertyId property) { + return storage.Create<PropertyLookup>(expr, storage.GetPropertyIx(dba.PropertyToName(property))); +} + +template <class TDbAccessor> +auto GetPropertyLookup(AstStorage &storage, TDbAccessor &dba, Expression *expr, const std::string &property) { + return storage.Create<PropertyLookup>(expr, storage.GetPropertyIx(property)); +} + +template <class TDbAccessor> +auto GetPropertyLookup(AstStorage &storage, TDbAccessor &, const std::string &name, + const std::pair<std::string, memgraph::storage::v3::PropertyId> &prop_pair) { + return storage.Create<PropertyLookup>(storage.Create<Identifier>(name), storage.GetPropertyIx(prop_pair.first)); +} + +template <class TDbAccessor> +auto GetPropertyLookup(AstStorage &storage, TDbAccessor &, Expression *expr, + const std::pair<std::string, memgraph::storage::v3::PropertyId> &prop_pair) { + return storage.Create<PropertyLookup>(expr, storage.GetPropertyIx(prop_pair.first)); +} + +/// Create an EdgeAtom with given name, direction and edge_type. +/// +/// Name is used to create the Identifier which is assigned to the edge. +auto GetEdge(AstStorage &storage, const std::string &name, EdgeAtom::Direction dir = EdgeAtom::Direction::BOTH, + const std::vector<std::string> &edge_types = {}) { + std::vector<EdgeTypeIx> types; + types.reserve(edge_types.size()); + for (const auto &type : edge_types) { + types.push_back(storage.GetEdgeTypeIx(type)); + } + return storage.Create<EdgeAtom>(storage.Create<Identifier>(name), EdgeAtom::Type::SINGLE, dir, types); +} + +/// Create a variable length expansion EdgeAtom with given name, direction and +/// edge_type. +/// +/// Name is used to create the Identifier which is assigned to the edge. +auto GetEdgeVariable(AstStorage &storage, const std::string &name, EdgeAtom::Type type = EdgeAtom::Type::DEPTH_FIRST, + EdgeAtom::Direction dir = EdgeAtom::Direction::BOTH, + const std::vector<std::string> &edge_types = {}, Identifier *flambda_inner_edge = nullptr, + Identifier *flambda_inner_node = nullptr, Identifier *wlambda_inner_edge = nullptr, + Identifier *wlambda_inner_node = nullptr, Expression *wlambda_expression = nullptr, + Identifier *total_weight = nullptr) { + std::vector<EdgeTypeIx> types; + types.reserve(edge_types.size()); + for (const auto &type : edge_types) { + types.push_back(storage.GetEdgeTypeIx(type)); + } + auto r_val = storage.Create<EdgeAtom>(storage.Create<Identifier>(name), type, dir, types); + + r_val->filter_lambda_.inner_edge = + flambda_inner_edge ? flambda_inner_edge : storage.Create<Identifier>(memgraph::utils::RandomString(20)); + r_val->filter_lambda_.inner_node = + flambda_inner_node ? flambda_inner_node : storage.Create<Identifier>(memgraph::utils::RandomString(20)); + + if (type == EdgeAtom::Type::WEIGHTED_SHORTEST_PATH) { + r_val->weight_lambda_.inner_edge = + wlambda_inner_edge ? wlambda_inner_edge : storage.Create<Identifier>(memgraph::utils::RandomString(20)); + r_val->weight_lambda_.inner_node = + wlambda_inner_node ? wlambda_inner_node : storage.Create<Identifier>(memgraph::utils::RandomString(20)); + r_val->weight_lambda_.expression = + wlambda_expression ? wlambda_expression : storage.Create<memgraph::query::v2::PrimitiveLiteral>(1); + + r_val->total_weight_ = total_weight; + } + + return r_val; +} + +/// Create a NodeAtom with given name and label. +/// +/// Name is used to create the Identifier which is assigned to the node. +auto GetNode(AstStorage &storage, const std::string &name, std::optional<std::string> label = std::nullopt) { + auto node = storage.Create<NodeAtom>(storage.Create<Identifier>(name)); + if (label) node->labels_.emplace_back(storage.GetLabelIx(*label)); + return node; +} + +/// Create a Pattern with given atoms. +auto GetPattern(AstStorage &storage, std::vector<PatternAtom *> atoms) { + auto pattern = storage.Create<Pattern>(); + pattern->identifier_ = storage.Create<Identifier>(memgraph::utils::RandomString(20), false); + pattern->atoms_.insert(pattern->atoms_.begin(), atoms.begin(), atoms.end()); + return pattern; +} + +/// Create a Pattern with given name and atoms. +auto GetPattern(AstStorage &storage, const std::string &name, std::vector<PatternAtom *> atoms) { + auto pattern = storage.Create<Pattern>(); + pattern->identifier_ = storage.Create<Identifier>(name, true); + pattern->atoms_.insert(pattern->atoms_.begin(), atoms.begin(), atoms.end()); + return pattern; +} + +/// This function fills an AST node which with given patterns. +/// +/// The function is most commonly used to create Match and Create clauses. +template <class TWithPatterns> +auto GetWithPatterns(TWithPatterns *with_patterns, std::vector<Pattern *> patterns) { + with_patterns->patterns_.insert(with_patterns->patterns_.begin(), patterns.begin(), patterns.end()); + return with_patterns; +} + +/// Create a query with given clauses. + +auto GetSingleQuery(SingleQuery *single_query, Clause *clause) { + single_query->clauses_.emplace_back(clause); + return single_query; +} +auto GetSingleQuery(SingleQuery *single_query, Match *match, Where *where) { + match->where_ = where; + single_query->clauses_.emplace_back(match); + return single_query; +} +auto GetSingleQuery(SingleQuery *single_query, With *with, Where *where) { + with->where_ = where; + single_query->clauses_.emplace_back(with); + return single_query; +} +template <class... T> +auto GetSingleQuery(SingleQuery *single_query, Match *match, Where *where, T *...clauses) { + match->where_ = where; + single_query->clauses_.emplace_back(match); + return GetSingleQuery(single_query, clauses...); +} +template <class... T> +auto GetSingleQuery(SingleQuery *single_query, With *with, Where *where, T *...clauses) { + with->where_ = where; + single_query->clauses_.emplace_back(with); + return GetSingleQuery(single_query, clauses...); +} + +template <class... T> +auto GetSingleQuery(SingleQuery *single_query, Clause *clause, T *...clauses) { + single_query->clauses_.emplace_back(clause); + return GetSingleQuery(single_query, clauses...); +} + +auto GetCypherUnion(CypherUnion *cypher_union, SingleQuery *single_query) { + cypher_union->single_query_ = single_query; + return cypher_union; +} + +auto GetQuery(AstStorage &storage, SingleQuery *single_query) { + auto *query = storage.Create<CypherQuery>(); + query->single_query_ = single_query; + return query; +} + +template <class... T> +auto GetQuery(AstStorage &storage, SingleQuery *single_query, T *...cypher_unions) { + auto *query = storage.Create<CypherQuery>(); + query->single_query_ = single_query; + query->cypher_unions_ = std::vector<CypherUnion *>{cypher_unions...}; + return query; +} + +// Helper functions for constructing RETURN and WITH clauses. +void FillReturnBody(AstStorage &, ReturnBody &body, NamedExpression *named_expr) { + body.named_expressions.emplace_back(named_expr); +} +void FillReturnBody(AstStorage &storage, ReturnBody &body, const std::string &name) { + if (name == "*") { + body.all_identifiers = true; + } else { + auto *ident = storage.Create<memgraph::query::v2::Identifier>(name); + auto *named_expr = storage.Create<memgraph::query::v2::NamedExpression>(name, ident); + body.named_expressions.emplace_back(named_expr); + } +} +void FillReturnBody(AstStorage &, ReturnBody &body, Limit limit) { body.limit = limit.expression; } +void FillReturnBody(AstStorage &, ReturnBody &body, Skip skip, Limit limit = Limit{}) { + body.skip = skip.expression; + body.limit = limit.expression; +} +void FillReturnBody(AstStorage &, ReturnBody &body, OrderBy order_by, Limit limit = Limit{}) { + body.order_by = order_by.expressions; + body.limit = limit.expression; +} +void FillReturnBody(AstStorage &, ReturnBody &body, OrderBy order_by, Skip skip, Limit limit = Limit{}) { + body.order_by = order_by.expressions; + body.skip = skip.expression; + body.limit = limit.expression; +} +void FillReturnBody(AstStorage &, ReturnBody &body, Expression *expr, NamedExpression *named_expr) { + // This overload supports `RETURN(expr, AS(name))` construct, since + // NamedExpression does not inherit Expression. + named_expr->expression_ = expr; + body.named_expressions.emplace_back(named_expr); +} +void FillReturnBody(AstStorage &storage, ReturnBody &body, const std::string &name, NamedExpression *named_expr) { + named_expr->expression_ = storage.Create<memgraph::query::v2::Identifier>(name); + body.named_expressions.emplace_back(named_expr); +} +template <class... T> +void FillReturnBody(AstStorage &storage, ReturnBody &body, Expression *expr, NamedExpression *named_expr, T... rest) { + named_expr->expression_ = expr; + body.named_expressions.emplace_back(named_expr); + FillReturnBody(storage, body, rest...); +} +template <class... T> +void FillReturnBody(AstStorage &storage, ReturnBody &body, NamedExpression *named_expr, T... rest) { + body.named_expressions.emplace_back(named_expr); + FillReturnBody(storage, body, rest...); +} +template <class... T> +void FillReturnBody(AstStorage &storage, ReturnBody &body, const std::string &name, NamedExpression *named_expr, + T... rest) { + named_expr->expression_ = storage.Create<memgraph::query::v2::Identifier>(name); + body.named_expressions.emplace_back(named_expr); + FillReturnBody(storage, body, rest...); +} +template <class... T> +void FillReturnBody(AstStorage &storage, ReturnBody &body, const std::string &name, T... rest) { + auto *ident = storage.Create<memgraph::query::v2::Identifier>(name); + auto *named_expr = storage.Create<memgraph::query::v2::NamedExpression>(name, ident); + body.named_expressions.emplace_back(named_expr); + FillReturnBody(storage, body, rest...); +} + +/// Create the return clause with given expressions. +/// +/// The supported expression combination of arguments is: +/// +/// (String | NamedExpression | (Expression NamedExpression))+ +/// [OrderBy] [Skip] [Limit] +/// +/// When the pair (Expression NamedExpression) is given, the Expression will be +/// moved inside the NamedExpression. This is done, so that the constructs like +/// RETURN(expr, AS("name"), ...) are supported. Taking a String is a shorthand +/// for RETURN(IDENT(string), AS(string), ....). +/// +/// @sa GetWith +template <class... T> +auto GetReturn(AstStorage &storage, bool distinct, T... exprs) { + auto ret = storage.Create<Return>(); + ret->body_.distinct = distinct; + FillReturnBody(storage, ret->body_, exprs...); + return ret; +} + +/// Create the with clause with given expressions. +/// +/// The supported expression combination is the same as for @c GetReturn. +/// +/// @sa GetReturn +template <class... T> +auto GetWith(AstStorage &storage, bool distinct, T... exprs) { + auto with = storage.Create<With>(); + with->body_.distinct = distinct; + FillReturnBody(storage, with->body_, exprs...); + return with; +} + +/// Create the UNWIND clause with given named expression. +auto GetUnwind(AstStorage &storage, NamedExpression *named_expr) { + return storage.Create<memgraph::query::v2::Unwind>(named_expr); +} +auto GetUnwind(AstStorage &storage, Expression *expr, NamedExpression *as) { + as->expression_ = expr; + return GetUnwind(storage, as); +} + +/// Create the delete clause with given named expressions. +auto GetDelete(AstStorage &storage, std::vector<Expression *> exprs, bool detach = false) { + auto del = storage.Create<Delete>(); + del->expressions_.insert(del->expressions_.begin(), exprs.begin(), exprs.end()); + del->detach_ = detach; + return del; +} + +/// Create a set property clause for given property lookup and the right hand +/// side expression. +auto GetSet(AstStorage &storage, PropertyLookup *prop_lookup, Expression *expr) { + return storage.Create<SetProperty>(prop_lookup, expr); +} + +/// Create a set properties clause for given identifier name and the right hand +/// side expression. +auto GetSet(AstStorage &storage, const std::string &name, Expression *expr, bool update = false) { + return storage.Create<SetProperties>(storage.Create<Identifier>(name), expr, update); +} + +/// Create a set labels clause for given identifier name and labels. +auto GetSet(AstStorage &storage, const std::string &name, std::vector<std::string> label_names) { + std::vector<LabelIx> labels; + labels.reserve(label_names.size()); + for (const auto &label : label_names) { + labels.push_back(storage.GetLabelIx(label)); + } + return storage.Create<SetLabels>(storage.Create<Identifier>(name), labels); +} + +/// Create a remove property clause for given property lookup +auto GetRemove(AstStorage &storage, PropertyLookup *prop_lookup) { return storage.Create<RemoveProperty>(prop_lookup); } + +/// Create a remove labels clause for given identifier name and labels. +auto GetRemove(AstStorage &storage, const std::string &name, std::vector<std::string> label_names) { + std::vector<LabelIx> labels; + labels.reserve(label_names.size()); + for (const auto &label : label_names) { + labels.push_back(storage.GetLabelIx(label)); + } + return storage.Create<RemoveLabels>(storage.Create<Identifier>(name), labels); +} + +/// Create a Merge clause for given Pattern with optional OnMatch and OnCreate +/// parts. +auto GetMerge(AstStorage &storage, Pattern *pattern, OnCreate on_create = OnCreate{}) { + auto *merge = storage.Create<memgraph::query::v2::Merge>(); + merge->pattern_ = pattern; + merge->on_create_ = on_create.set; + return merge; +} +auto GetMerge(AstStorage &storage, Pattern *pattern, OnMatch on_match, OnCreate on_create = OnCreate{}) { + auto *merge = storage.Create<memgraph::query::v2::Merge>(); + merge->pattern_ = pattern; + merge->on_match_ = on_match.set; + merge->on_create_ = on_create.set; + return merge; +} + +auto GetCallProcedure(AstStorage &storage, std::string procedure_name, + std::vector<memgraph::query::v2::Expression *> arguments = {}) { + auto *call_procedure = storage.Create<memgraph::query::v2::CallProcedure>(); + call_procedure->procedure_name_ = std::move(procedure_name); + call_procedure->arguments_ = std::move(arguments); + return call_procedure; +} + +/// Create the FOREACH clause with given named expression. +auto GetForeach(AstStorage &storage, NamedExpression *named_expr, const std::vector<query::v2::Clause *> &clauses) { + return storage.Create<query::v2::Foreach>(named_expr, clauses); +} + +} // namespace test_common + +} // namespace memgraph::query::v2 + +/// All the following macros implicitly pass `storage` variable to functions. +/// You need to have `AstStorage storage;` somewhere in scope to use them. +/// Refer to function documentation to see what the macro does. +/// +/// Example usage: +/// +/// // Create MATCH (n) -[r]- (m) RETURN m AS new_name +/// AstStorage storage; +/// auto query = QUERY(MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m"))), +/// RETURN(NEXPR("new_name"), IDENT("m"))); +#define NODE(...) memgraph::expr::test_common::GetNode(storage, __VA_ARGS__) +#define EDGE(...) memgraph::expr::test_common::GetEdge(storage, __VA_ARGS__) +#define EDGE_VARIABLE(...) memgraph::expr::test_common::GetEdgeVariable(storage, __VA_ARGS__) +#define PATTERN(...) memgraph::expr::test_common::GetPattern(storage, {__VA_ARGS__}) +#define PATTERN(...) memgraph::expr::test_common::GetPattern(storage, {__VA_ARGS__}) +#define NAMED_PATTERN(name, ...) memgraph::expr::test_common::GetPattern(storage, name, {__VA_ARGS__}) +#define OPTIONAL_MATCH(...) \ + memgraph::expr::test_common::GetWithPatterns(storage.Create<memgraph::query::v2::Match>(true), {__VA_ARGS__}) +#define MATCH(...) \ + memgraph::expr::test_common::GetWithPatterns(storage.Create<memgraph::query::v2::Match>(), {__VA_ARGS__}) +#define WHERE(expr) storage.Create<memgraph::query::v2::Where>((expr)) +#define CREATE(...) \ + memgraph::expr::test_common::GetWithPatterns(storage.Create<memgraph::query::v2::Create>(), {__VA_ARGS__}) +#define IDENT(...) storage.Create<memgraph::query::v2::Identifier>(__VA_ARGS__) +#define LITERAL(val) storage.Create<memgraph::query::v2::PrimitiveLiteral>((val)) +#define LIST(...) \ + storage.Create<memgraph::query::v2::ListLiteral>(std::vector<memgraph::query::v2::Expression *>{__VA_ARGS__}) +#define MAP(...) \ + storage.Create<memgraph::query::v2::MapLiteral>( \ + std::unordered_map<memgraph::query::v2::PropertyIx, memgraph::query::v2::Expression *>{__VA_ARGS__}) +#define PROPERTY_PAIR(property_name) \ + std::make_pair(property_name, dba.NameToProperty(property_name)) // This one might not be needed at all +#define PRIMARY_PROPERTY_PAIR(property_name) std::make_pair(property_name, dba.NameToPrimaryProperty(property_name)) +#define SECONDARY_PROPERTY_PAIR(property_name) std::make_pair(property_name, dba.NameToSecondaryProperty(property_name)) +#define PROPERTY_LOOKUP(...) memgraph::expr::test_common::GetPropertyLookup(storage, dba, __VA_ARGS__) +#define PARAMETER_LOOKUP(token_position) storage.Create<memgraph::query::v2::ParameterLookup>((token_position)) +#define NEXPR(name, expr) storage.Create<memgraph::query::v2::NamedExpression>((name), (expr)) +// AS is alternative to NEXPR which does not initialize NamedExpression with +// Expression. It should be used with RETURN or WITH. For example: +// RETURN(IDENT("n"), AS("n")) vs. RETURN(NEXPR("n", IDENT("n"))). +#define AS(name) storage.Create<memgraph::query::v2::NamedExpression>((name)) +#define RETURN(...) memgraph::expr::test_common::GetReturn(storage, false, __VA_ARGS__) +#define WITH(...) memgraph::expr::test_common::GetWith(storage, false, __VA_ARGS__) +#define RETURN_DISTINCT(...) memgraph::expr::test_common::GetReturn(storage, true, __VA_ARGS__) +#define WITH_DISTINCT(...) memgraph::expr::test_common::GetWith(storage, true, __VA_ARGS__) +#define UNWIND(...) memgraph::expr::test_common::GetUnwind(storage, __VA_ARGS__) +#define ORDER_BY(...) memgraph::expr::test_common::GetOrderBy(__VA_ARGS__) +#define SKIP(expr) \ + memgraph::expr::test_common::Skip { (expr) } +#define LIMIT(expr) \ + memgraph::expr::test_common::Limit { (expr) } +#define DELETE(...) memgraph::expr::test_common::GetDelete(storage, {__VA_ARGS__}) +#define DETACH_DELETE(...) memgraph::expr::test_common::GetDelete(storage, {__VA_ARGS__}, true) +#define SET(...) memgraph::expr::test_common::GetSet(storage, __VA_ARGS__) +#define REMOVE(...) memgraph::expr::test_common::GetRemove(storage, __VA_ARGS__) +#define MERGE(...) memgraph::expr::test_common::GetMerge(storage, __VA_ARGS__) +#define ON_MATCH(...) \ + memgraph::expr::test_common::OnMatch { \ + std::vector<memgraph::query::v2::Clause *> { __VA_ARGS__ } \ + } +#define ON_CREATE(...) \ + memgraph::expr::test_common::OnCreate { \ + std::vector<memgraph::query::v2::Clause *> { __VA_ARGS__ } \ + } +#define CREATE_INDEX_ON(label, property) \ + storage.Create<memgraph::query::v2::IndexQuery>(memgraph::query::v2::IndexQuery::Action::CREATE, (label), \ + std::vector<memgraph::query::v2::PropertyIx>{(property)}) +#define QUERY(...) memgraph::expr::test_common::GetQuery(storage, __VA_ARGS__) +#define SINGLE_QUERY(...) memgraph::expr::test_common::GetSingleQuery(storage.Create<SingleQuery>(), __VA_ARGS__) +#define UNION(...) memgraph::expr::test_common::GetCypherUnion(storage.Create<CypherUnion>(true), __VA_ARGS__) +#define UNION_ALL(...) memgraph::expr::test_common::GetCypherUnion(storage.Create<CypherUnion>(false), __VA_ARGS__) +#define FOREACH(...) memgraph::expr::test_common::GetForeach(storage, __VA_ARGS__) +// Various operators +#define NOT(expr) storage.Create<memgraph::query::v2::NotOperator>((expr)) +#define UPLUS(expr) storage.Create<memgraph::query::v2::UnaryPlusOperator>((expr)) +#define UMINUS(expr) storage.Create<memgraph::query::v2::UnaryMinusOperator>((expr)) +#define IS_NULL(expr) storage.Create<memgraph::query::v2::IsNullOperator>((expr)) +#define ADD(expr1, expr2) storage.Create<memgraph::query::v2::AdditionOperator>((expr1), (expr2)) +#define LESS(expr1, expr2) storage.Create<memgraph::query::v2::LessOperator>((expr1), (expr2)) +#define LESS_EQ(expr1, expr2) storage.Create<memgraph::query::v2::LessEqualOperator>((expr1), (expr2)) +#define GREATER(expr1, expr2) storage.Create<memgraph::query::v2::GreaterOperator>((expr1), (expr2)) +#define GREATER_EQ(expr1, expr2) storage.Create<memgraph::query::v2::GreaterEqualOperator>((expr1), (expr2)) +#define SUM(expr) \ + storage.Create<memgraph::query::v2::Aggregation>((expr), nullptr, memgraph::query::v2::Aggregation::Op::SUM) +#define COUNT(expr) \ + storage.Create<memgraph::query::v2::Aggregation>((expr), nullptr, memgraph::query::v2::Aggregation::Op::COUNT) +#define AVG(expr) \ + storage.Create<memgraph::query::v2::Aggregation>((expr), nullptr, memgraph::query::v2::Aggregation::Op::AVG) +#define COLLECT_LIST(expr) \ + storage.Create<memgraph::query::v2::Aggregation>((expr), nullptr, memgraph::query::v2::Aggregation::Op::COLLECT_LIST) +#define EQ(expr1, expr2) storage.Create<memgraph::query::v2::EqualOperator>((expr1), (expr2)) +#define NEQ(expr1, expr2) storage.Create<memgraph::query::v2::NotEqualOperator>((expr1), (expr2)) +#define AND(expr1, expr2) storage.Create<memgraph::query::v2::AndOperator>((expr1), (expr2)) +#define OR(expr1, expr2) storage.Create<memgraph::query::v2::OrOperator>((expr1), (expr2)) +#define IN_LIST(expr1, expr2) storage.Create<memgraph::query::v2::InListOperator>((expr1), (expr2)) +#define IF(cond, then, else) storage.Create<memgraph::query::v2::IfOperator>((cond), (then), (else)) +// Function call +#define FN(function_name, ...) \ + storage.Create<memgraph::query::v2::Function>(memgraph::utils::ToUpperCase(function_name), \ + std::vector<memgraph::query::v2::Expression *>{__VA_ARGS__}) +// List slicing +#define SLICE(list, lower_bound, upper_bound) \ + storage.Create<memgraph::query::v2::ListSlicingOperator>(list, lower_bound, upper_bound) +// all(variable IN list WHERE predicate) +#define ALL(variable, list, where) \ + storage.Create<memgraph::query::v2::All>(storage.Create<memgraph::query::v2::Identifier>(variable), list, where) +#define SINGLE(variable, list, where) \ + storage.Create<memgraph::query::v2::Single>(storage.Create<memgraph::query::v2::Identifier>(variable), list, where) +#define ANY(variable, list, where) \ + storage.Create<memgraph::query::v2::Any>(storage.Create<memgraph::query::v2::Identifier>(variable), list, where) +#define NONE(variable, list, where) \ + storage.Create<memgraph::query::v2::None>(storage.Create<memgraph::query::v2::Identifier>(variable), list, where) +#define REDUCE(accumulator, initializer, variable, list, expr) \ + storage.Create<memgraph::query::v2::Reduce>(storage.Create<memgraph::query::v2::Identifier>(accumulator), \ + initializer, storage.Create<memgraph::query::v2::Identifier>(variable), \ + list, expr) +#define COALESCE(...) \ + storage.Create<memgraph::query::v2::Coalesce>(std::vector<memgraph::query::v2::Expression *>{__VA_ARGS__}) +#define EXTRACT(variable, list, expr) \ + storage.Create<memgraph::query::v2::Extract>(storage.Create<memgraph::query::v2::Identifier>(variable), list, expr) +#define AUTH_QUERY(action, user, role, user_or_role, password, privileges) \ + storage.Create<memgraph::query::v2::AuthQuery>((action), (user), (role), (user_or_role), password, (privileges)) +#define DROP_USER(usernames) storage.Create<memgraph::query::v2::DropUser>((usernames)) +#define CALL_PROCEDURE(...) memgraph::query::v2::test_common::GetCallProcedure(storage, __VA_ARGS__) diff --git a/tests/unit/query_v2_expression_evaluator.cpp b/tests/unit/query_v2_expression_evaluator.cpp index 5e91d0d5a..0000e62b2 100644 --- a/tests/unit/query_v2_expression_evaluator.cpp +++ b/tests/unit/query_v2_expression_evaluator.cpp @@ -1,4 +1,4 @@ -// Copyright 2022 Memgraph Ltd. +// Copyright 2023 Memgraph Ltd. // // Use of this software is governed by the Business Source License // included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source @@ -125,6 +125,16 @@ class MockedRequestRouter : public RequestRouterInterface { bool IsPrimaryKey(LabelId primary_label, PropertyId property) const override { return true; } + std::optional<std::pair<uint64_t, uint64_t>> AllocateInitialEdgeIds(io::Address coordinator_address) override { + return {}; + } + + void InstallSimulatorTicker(std::function<bool()> tick_simulator) override {} + const std::vector<coordinator::SchemaProperty> &GetSchemaForLabel(storage::v3::LabelId /*label*/) const override { + static std::vector<coordinator::SchemaProperty> schema; + return schema; + }; + private: void SetUpNameIdMappers() { std::unordered_map<uint64_t, std::string> id_to_name; diff --git a/tests/unit/query_v2_plan.cpp b/tests/unit/query_v2_plan.cpp new file mode 100644 index 000000000..8fc5e11c7 --- /dev/null +++ b/tests/unit/query_v2_plan.cpp @@ -0,0 +1,178 @@ +// Copyright 2023 Memgraph Ltd. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source +// License, and you may not use this file except in compliance with the Business Source License. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +#include "query_plan_checker_v2.hpp" + +#include <iostream> +#include <list> +#include <sstream> +#include <tuple> +#include <typeinfo> +#include <unordered_set> +#include <variant> + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +#include "expr/semantic/symbol_generator.hpp" +#include "query/frontend/semantic/symbol_table.hpp" +#include "query/v2/frontend/ast/ast.hpp" +#include "query/v2/plan/operator.hpp" +#include "query/v2/plan/planner.hpp" + +#include "query_v2_common.hpp" + +namespace memgraph::query { +::std::ostream &operator<<(::std::ostream &os, const Symbol &sym) { + return os << "Symbol{\"" << sym.name() << "\" [" << sym.position() << "] " << Symbol::TypeToString(sym.type()) << "}"; +} +} // namespace memgraph::query + +// using namespace memgraph::query::v2::plan; +using namespace memgraph::expr::plan; +using memgraph::query::Symbol; +using memgraph::query::SymbolGenerator; +using memgraph::query::v2::AstStorage; +using memgraph::query::v2::SingleQuery; +using memgraph::query::v2::SymbolTable; +using Type = memgraph::query::v2::EdgeAtom::Type; +using Direction = memgraph::query::v2::EdgeAtom::Direction; +using Bound = ScanAllByLabelPropertyRange::Bound; + +namespace { + +class Planner { + public: + template <class TDbAccessor> + Planner(std::vector<SingleQueryPart> single_query_parts, PlanningContext<TDbAccessor> context) { + memgraph::expr::Parameters parameters; + PostProcessor post_processor(parameters); + plan_ = MakeLogicalPlanForSingleQuery<RuleBasedPlanner>(single_query_parts, &context); + plan_ = post_processor.Rewrite(std::move(plan_), &context); + } + + auto &plan() { return *plan_; } + + private: + std::unique_ptr<LogicalOperator> plan_; +}; + +template <class... TChecker> +auto CheckPlan(LogicalOperator &plan, const SymbolTable &symbol_table, TChecker... checker) { + std::list<BaseOpChecker *> checkers{&checker...}; + PlanChecker plan_checker(checkers, symbol_table); + plan.Accept(plan_checker); + EXPECT_TRUE(plan_checker.checkers_.empty()); +} + +template <class TPlanner, class... TChecker> +auto CheckPlan(memgraph::query::v2::CypherQuery *query, AstStorage &storage, TChecker... checker) { + auto symbol_table = memgraph::expr::MakeSymbolTable(query); + FakeDistributedDbAccessor dba; + auto planner = MakePlanner<TPlanner>(&dba, storage, symbol_table, query); + CheckPlan(planner.plan(), symbol_table, checker...); +} + +template <class T> +class TestPlanner : public ::testing::Test {}; + +using PlannerTypes = ::testing::Types<Planner>; + +TYPED_TEST_CASE(TestPlanner, PlannerTypes); + +TYPED_TEST(TestPlanner, MatchFilterPropIsNotNull) { + const char *prim_label_name = "prim_label_one"; + // Exact primary key match, one elem as PK. + { + FakeDistributedDbAccessor dba; + auto label = dba.Label(prim_label_name); + auto prim_prop_one = PRIMARY_PROPERTY_PAIR("prim_prop_one"); + + dba.SetIndexCount(label, 1); + dba.SetIndexCount(label, prim_prop_one.second, 1); + + dba.CreateSchema(label, {prim_prop_one.second}); + + memgraph::query::v2::AstStorage storage; + + memgraph::query::v2::Expression *expected_primary_key; + expected_primary_key = PROPERTY_LOOKUP("n", prim_prop_one); + auto *query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n", prim_label_name))), + WHERE(EQ(PROPERTY_LOOKUP("n", prim_prop_one), LITERAL(1))), RETURN("n"))); + auto symbol_table = (memgraph::expr::MakeSymbolTable(query)); + auto planner = MakePlanner<TypeParam>(&dba, storage, symbol_table, query); + CheckPlan(planner.plan(), symbol_table, ExpectScanByPrimaryKey(label, {expected_primary_key}), ExpectProduce()); + } + // Exact primary key match, two elem as PK. + { + FakeDistributedDbAccessor dba; + auto label = dba.Label(prim_label_name); + auto prim_prop_one = PRIMARY_PROPERTY_PAIR("prim_prop_one"); + + auto prim_prop_two = PRIMARY_PROPERTY_PAIR("prim_prop_two"); + auto sec_prop_one = PRIMARY_PROPERTY_PAIR("sec_prop_one"); + auto sec_prop_two = PRIMARY_PROPERTY_PAIR("sec_prop_two"); + auto sec_prop_three = PRIMARY_PROPERTY_PAIR("sec_prop_three"); + + dba.SetIndexCount(label, 1); + dba.SetIndexCount(label, prim_prop_one.second, 1); + + dba.CreateSchema(label, {prim_prop_one.second, prim_prop_two.second}); + + dba.SetIndexCount(label, prim_prop_two.second, 1); + dba.SetIndexCount(label, sec_prop_one.second, 1); + dba.SetIndexCount(label, sec_prop_two.second, 1); + dba.SetIndexCount(label, sec_prop_three.second, 1); + memgraph::query::v2::AstStorage storage; + + memgraph::query::v2::Expression *expected_primary_key; + expected_primary_key = PROPERTY_LOOKUP("n", prim_prop_one); + auto *query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n", prim_label_name))), + WHERE(AND(EQ(PROPERTY_LOOKUP("n", prim_prop_one), LITERAL(1)), + EQ(PROPERTY_LOOKUP("n", prim_prop_two), LITERAL(1)))), + RETURN("n"))); + auto symbol_table = (memgraph::expr::MakeSymbolTable(query)); + auto planner = MakePlanner<TypeParam>(&dba, storage, symbol_table, query); + CheckPlan(planner.plan(), symbol_table, ExpectScanByPrimaryKey(label, {expected_primary_key}), ExpectProduce()); + } + // One elem is missing from PK, default to ScanAllByLabelPropertyValue. + { + FakeDistributedDbAccessor dba; + auto label = dba.Label(prim_label_name); + + auto prim_prop_one = PRIMARY_PROPERTY_PAIR("prim_prop_one"); + auto prim_prop_two = PRIMARY_PROPERTY_PAIR("prim_prop_two"); + + auto sec_prop_one = PRIMARY_PROPERTY_PAIR("sec_prop_one"); + auto sec_prop_two = PRIMARY_PROPERTY_PAIR("sec_prop_two"); + auto sec_prop_three = PRIMARY_PROPERTY_PAIR("sec_prop_three"); + + dba.SetIndexCount(label, 1); + dba.SetIndexCount(label, prim_prop_one.second, 1); + + dba.CreateSchema(label, {prim_prop_one.second, prim_prop_two.second}); + + dba.SetIndexCount(label, prim_prop_two.second, 1); + dba.SetIndexCount(label, sec_prop_one.second, 1); + dba.SetIndexCount(label, sec_prop_two.second, 1); + dba.SetIndexCount(label, sec_prop_three.second, 1); + memgraph::query::v2::AstStorage storage; + + auto *query = QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n", prim_label_name))), + WHERE(EQ(PROPERTY_LOOKUP("n", prim_prop_one), LITERAL(1))), RETURN("n"))); + auto symbol_table = (memgraph::expr::MakeSymbolTable(query)); + auto planner = MakePlanner<TypeParam>(&dba, storage, symbol_table, query); + CheckPlan(planner.plan(), symbol_table, ExpectScanAllByLabelPropertyValue(label, prim_prop_one, IDENT("n")), + ExpectProduce()); + } +} + +} // namespace