From 252018ab2213d19baaab5abcc961ecf9d6c0af19 Mon Sep 17 00:00:00 2001 From: Teon Banek Date: Mon, 22 Jan 2018 10:27:00 +0100 Subject: [PATCH] Serialize query plan operators Summary: With the added support for serialization, we should be able to transfer plans across distributed workers. The planner tests has been extended to test serialization. Operators should be mostly tested, but the expression they contain aren't completely. The quick solution is to use typeid for rudimentary expression equality testing. The more involved solution of comparing the expression tree for equality is the correct choice. It should be done in the near future. Reviewers: florijan, msantl Reviewed By: florijan Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D1122 --- src/query/frontend/ast/ast.cpp | 20 +- src/query/frontend/ast/ast.hpp | 25 +- src/query/frontend/semantic/symbol.hpp | 3 + src/query/interpret/eval.hpp | 1 + src/query/plan/operator.cpp | 47 +- src/query/plan/operator.hpp | 762 ++++++++++++++++++++++--- tests/unit/query_planner.cpp | 555 ++++++++++-------- 7 files changed, 1059 insertions(+), 354 deletions(-) diff --git a/src/query/frontend/ast/ast.cpp b/src/query/frontend/ast/ast.cpp index 16c48bf80..a9ee568ef 100644 --- a/src/query/frontend/ast/ast.cpp +++ b/src/query/frontend/ast/ast.cpp @@ -7,6 +7,18 @@ namespace query { +// Id for boost's archive get_helper needs to be unique among all ids. If it +// isn't unique, then different instances (even across different types!) will +// replace the previous helper. The documentation recommends to take an address +// of a function, which should be unique in the produced binary. +// It might seem logical to take an address of a AstTreeStorage constructor, but +// according to c++ standard, the constructor function need not have an address. +// Additionally, pointers to member functions are not required to contain the +// address of the function +// (https://isocpp.org/wiki/faq/pointers-to-members#addr-of-memfn). So, to be +// safe, use a regular top-level function. +void *const AstTreeStorage::kHelperId = (void *)CloneReturnBody; + AstTreeStorage::AstTreeStorage() { storage_.emplace_back(new Query(next_uid_++)); } @@ -15,14 +27,6 @@ Query *AstTreeStorage::query() const { return dynamic_cast(storage_[0].get()); } -int AstTreeStorage::MaximumStorageUid() const { - int max_uid = -1; - for (const auto &tree : storage_) { - max_uid = std::max(max_uid, tree->uid()); - } - return max_uid; -} - ReturnBody CloneReturnBody(AstTreeStorage &storage, const ReturnBody &body) { ReturnBody new_body; new_body.distinct = body.distinct; diff --git a/src/query/frontend/ast/ast.hpp b/src/query/frontend/ast/ast.hpp index 00f5a998f..589a6d0a8 100644 --- a/src/query/frontend/ast/ast.hpp +++ b/src/query/frontend/ast/ast.hpp @@ -1,6 +1,5 @@ #pragma once -#include #include #include #include @@ -9,13 +8,10 @@ #include "boost/serialization/split_member.hpp" #include "boost/serialization/string.hpp" #include "boost/serialization/vector.hpp" -#include "glog/logging.h" -#include "database/graph_db.hpp" #include "query/frontend/ast/ast_visitor.hpp" #include "query/frontend/semantic/symbol.hpp" #include "query/interpret/awesome_memgraph_functions.hpp" -#include "query/parameters.hpp" #include "query/typed_value.hpp" #include "storage/types.hpp" #include "utils/serialization.hpp" @@ -35,6 +31,10 @@ struct hash> { }; } // namespace std +namespace database { +class GraphDbAccessor; +} + namespace query { #define CLONE_BINARY_EXPRESSION \ @@ -60,6 +60,7 @@ namespace query { class Tree; + // It would be better to call this AstTree, but we already have a class Tree, // which could be renamed to Node or AstTreeNode, but we also have a class // called NodeAtom... @@ -82,14 +83,16 @@ class AstTreeStorage { Query *query() const; + /// Id for using get_helper in boost archives. + static void * const kHelperId; + /// Load an Ast Node into this storage. template void Load(TArchive &ar, TNode &node) { - auto &tmp_ast = ar.template get_helper(); - tmp_ast.storage_ = std::move(storage_); + auto &tmp_ast = ar.template get_helper(kHelperId); + std::swap(*this, tmp_ast); ar >> node; - storage_ = std::move(tmp_ast.storage_); - next_uid_ = MaximumStorageUid() + 1; + std::swap(*this, tmp_ast); } /// Load a Query into this storage. @@ -102,8 +105,6 @@ class AstTreeStorage { int next_uid_ = 0; std::vector> storage_; - int MaximumStorageUid() const; - template friend void LoadPointer(TArchive &ar, TNode *&node); }; @@ -117,7 +118,8 @@ template void LoadPointer(TArchive &ar, TNode *&node) { ar >> node; if (node) { - auto &ast_storage = ar.template get_helper(); + auto &ast_storage = + ar.template get_helper(AstTreeStorage::kHelperId); auto found = std::find_if(ast_storage.storage_.begin(), ast_storage.storage_.end(), [&](const auto &n) { return n->uid() == node->uid(); }); @@ -127,6 +129,7 @@ void LoadPointer(TArchive &ar, TNode *&node) { dynamic_cast(found->get()) == node); if (ast_storage.storage_.end() == found) { ast_storage.storage_.emplace_back(node); + ast_storage.next_uid_ = std::max(ast_storage.next_uid_, node->uid() + 1); } } } diff --git a/src/query/frontend/semantic/symbol.hpp b/src/query/frontend/semantic/symbol.hpp index 3e6e85e12..a6becfd19 100644 --- a/src/query/frontend/semantic/symbol.hpp +++ b/src/query/frontend/semantic/symbol.hpp @@ -2,6 +2,9 @@ #include +#include "boost/serialization/serialization.hpp" +#include "boost/serialization/string.hpp" + namespace query { class Symbol { diff --git a/src/query/interpret/eval.hpp b/src/query/interpret/eval.hpp index a01782617..a4d6f9b83 100644 --- a/src/query/interpret/eval.hpp +++ b/src/query/interpret/eval.hpp @@ -11,6 +11,7 @@ #include "query/frontend/ast/ast.hpp" #include "query/frontend/semantic/symbol_table.hpp" #include "query/interpret/frame.hpp" +#include "query/parameters.hpp" #include "query/typed_value.hpp" #include "utils/exceptions.hpp" diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index a091e0181..47845bf5c 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -1,16 +1,23 @@ +#include "query/plan/operator.hpp" + #include #include +#include #include #include +#include "boost/archive/binary_iarchive.hpp" +#include "boost/archive/binary_oarchive.hpp" +#include "boost/serialization/export.hpp" #include "glog/logging.h" +#include "database/graph_db_accessor.hpp" #include "query/context.hpp" #include "query/exceptions.hpp" #include "query/frontend/ast/ast.hpp" +#include "query/frontend/semantic/symbol_table.hpp" #include "query/interpret/eval.hpp" - -#include "query/plan/operator.hpp" +#include "query/path.hpp" // macro for the default implementation of LogicalOperator::Accept // that accepts the visitor and visits it's input_ operator @@ -79,7 +86,7 @@ std::unique_ptr Once::MakeCursor(database::GraphDbAccessor &) const { void Once::OnceCursor::Reset() { did_pull_ = false; } -CreateNode::CreateNode(const NodeAtom *node_atom, +CreateNode::CreateNode(NodeAtom *node_atom, const std::shared_ptr &input) : node_atom_(node_atom), input_(input ? input : std::make_shared()) {} @@ -117,7 +124,7 @@ void CreateNode::CreateNodeCursor::Create(Frame &frame, Context &context) { frame[context.symbol_table_.at(*self_.node_atom_->identifier_)] = new_node; } -CreateExpand::CreateExpand(const NodeAtom *node_atom, const EdgeAtom *edge_atom, +CreateExpand::CreateExpand(NodeAtom *node_atom, EdgeAtom *edge_atom, const std::shared_ptr &input, Symbol input_symbol, bool existing_node) : node_atom_(node_atom), @@ -2546,3 +2553,35 @@ void Union::UnionCursor::Reset() { } } // namespace query::plan + +BOOST_CLASS_EXPORT(query::plan::Once); +BOOST_CLASS_EXPORT(query::plan::CreateNode); +BOOST_CLASS_EXPORT(query::plan::CreateExpand); +BOOST_CLASS_EXPORT(query::plan::ScanAll); +BOOST_CLASS_EXPORT(query::plan::ScanAllByLabel); +BOOST_CLASS_EXPORT(query::plan::ScanAllByLabelPropertyRange); +BOOST_CLASS_EXPORT(query::plan::ScanAllByLabelPropertyValue); +BOOST_CLASS_EXPORT(query::plan::Expand); +BOOST_CLASS_EXPORT(query::plan::ExpandVariable); +BOOST_CLASS_EXPORT(query::plan::Filter); +BOOST_CLASS_EXPORT(query::plan::Produce); +BOOST_CLASS_EXPORT(query::plan::ConstructNamedPath); +BOOST_CLASS_EXPORT(query::plan::Delete); +BOOST_CLASS_EXPORT(query::plan::SetProperty); +BOOST_CLASS_EXPORT(query::plan::SetProperties); +BOOST_CLASS_EXPORT(query::plan::SetLabels); +BOOST_CLASS_EXPORT(query::plan::RemoveProperty); +BOOST_CLASS_EXPORT(query::plan::RemoveLabels); +BOOST_CLASS_EXPORT(query::plan::ExpandUniquenessFilter); +BOOST_CLASS_EXPORT(query::plan::ExpandUniquenessFilter); +BOOST_CLASS_EXPORT(query::plan::Accumulate); +BOOST_CLASS_EXPORT(query::plan::Aggregate); +BOOST_CLASS_EXPORT(query::plan::Skip); +BOOST_CLASS_EXPORT(query::plan::Limit); +BOOST_CLASS_EXPORT(query::plan::OrderBy); +BOOST_CLASS_EXPORT(query::plan::Merge); +BOOST_CLASS_EXPORT(query::plan::Optional); +BOOST_CLASS_EXPORT(query::plan::Unwind); +BOOST_CLASS_EXPORT(query::plan::Distinct); +BOOST_CLASS_EXPORT(query::plan::CreateIndex); +BOOST_CLASS_EXPORT(query::plan::Union); diff --git a/src/query/plan/operator.hpp b/src/query/plan/operator.hpp index 0161f87b1..000c4d41a 100644 --- a/src/query/plan/operator.hpp +++ b/src/query/plan/operator.hpp @@ -2,33 +2,38 @@ #pragma once -#include -#include -#include #include #include -#include -#include #include #include +#include #include -#include "database/graph_db_accessor.hpp" +#include +#include "boost/serialization/base_object.hpp" +#include "boost/serialization/serialization.hpp" +#include "boost/serialization/shared_ptr.hpp" +#include "boost/serialization/unique_ptr.hpp" + #include "query/common.hpp" -#include "query/exceptions.hpp" -#include "query/frontend/semantic/symbol_table.hpp" -#include "query/path.hpp" +#include "query/frontend/ast/ast.hpp" +#include "query/frontend/semantic/symbol.hpp" #include "query/typed_value.hpp" #include "storage/types.hpp" #include "utils/bound.hpp" #include "utils/hashing/fnv.hpp" #include "utils/visitor.hpp" +namespace database { +class GraphDbAccessor; +} + namespace query { class Context; class ExpressionEvaluator; class Frame; +class SymbolTable; namespace plan { @@ -150,8 +155,23 @@ class LogicalOperator } virtual ~LogicalOperator() {} + + private: + friend class boost::serialization::access; + + template + void serialize(TArchive &ar, const unsigned int) {} }; +template +std::pair, AstTreeStorage> LoadPlan( + TArchive &ar) { + std::unique_ptr root; + ar >> root; + return {std::move(root), std::move(ar.template get_helper( + AstTreeStorage::kHelperId))}; +} + /** * A logical operator whose Cursor returns true on the first Pull * and false on every following Pull. @@ -172,6 +192,12 @@ class Once : public LogicalOperator { private: bool did_pull_{false}; }; + + friend class boost::serialization::access; + template + void serialize(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object(*this); + } }; /** @brief Operator for creating a node. @@ -192,15 +218,17 @@ class CreateNode : public LogicalOperator { * If a valid input, then a node will be created for each * successful pull from the given input. */ - CreateNode(const NodeAtom *node_atom, + CreateNode(NodeAtom *node_atom, const std::shared_ptr &input); bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; std::unique_ptr MakeCursor( database::GraphDbAccessor &db) const override; private: - const NodeAtom *node_atom_ = nullptr; - const std::shared_ptr input_; + CreateNode() {} + + NodeAtom *node_atom_ = nullptr; + std::shared_ptr input_; class CreateNodeCursor : public Cursor { public: @@ -218,6 +246,24 @@ class CreateNode : public LogicalOperator { */ void Create(Frame &, Context &); }; + + friend class boost::serialization::access; + + BOOST_SERIALIZATION_SPLIT_MEMBER(); + + template + void save(TArchive &ar, const unsigned int) const { + ar &boost::serialization::base_object(*this); + ar &input_; + SavePointer(ar, node_atom_); + } + + template + void load(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object(*this); + ar &input_; + LoadPointer(ar, node_atom_); + } }; /** @brief Operator for creating edges and destination nodes. @@ -247,7 +293,7 @@ class CreateExpand : public LogicalOperator { * @param existing_node @c bool indicating whether the @c node_atom refers to * an existing node. If @c false, the operator will also create the node. */ - CreateExpand(const NodeAtom *node_atom, const EdgeAtom *edge_atom, + CreateExpand(NodeAtom *node_atom, EdgeAtom *edge_atom, const std::shared_ptr &input, Symbol input_symbol, bool existing_node); bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; @@ -256,17 +302,19 @@ class CreateExpand : public LogicalOperator { private: // info on what's getting expanded - const NodeAtom *node_atom_; - const EdgeAtom *edge_atom_; + NodeAtom *node_atom_; + EdgeAtom *edge_atom_; // the input op and the symbol under which the op's result // can be found in the frame - const std::shared_ptr input_; - const Symbol input_symbol_; + std::shared_ptr input_; + Symbol input_symbol_; // if the given node atom refers to an existing node // (either matched or created) - const bool existing_node_; + bool existing_node_; + + CreateExpand() {} class CreateExpandCursor : public Cursor { public: @@ -298,6 +346,30 @@ class CreateExpand : public LogicalOperator { const SymbolTable &symbol_table, ExpressionEvaluator &evaluator); }; + + friend class boost::serialization::access; + + BOOST_SERIALIZATION_SPLIT_MEMBER(); + + template + void save(TArchive &ar, const unsigned int) const { + ar &boost::serialization::base_object(*this); + SavePointer(ar, node_atom_); + SavePointer(ar, edge_atom_); + ar &input_; + ar &input_symbol_; + ar &existing_node_; + } + + template + void load(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object(*this); + LoadPointer(ar, node_atom_); + LoadPointer(ar, edge_atom_); + ar &input_; + ar &input_symbol_; + ar &existing_node_; + } }; /** @@ -329,8 +401,8 @@ class ScanAll : public LogicalOperator { auto graph_view() const { return graph_view_; } protected: - const std::shared_ptr input_; - const Symbol output_symbol_; + std::shared_ptr input_; + Symbol output_symbol_; /** * @brief Controls which graph state is used to produce vertices. * @@ -339,7 +411,20 @@ class ScanAll : public LogicalOperator { * command. With @c GraphView::NEW, all vertices will be produced the current * transaction sees along with their modifications. */ - const GraphView graph_view_; + GraphView graph_view_; + + ScanAll() {} + + private: + friend class boost::serialization::access; + + template + void serialize(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object(*this); + ar &input_; + ar &output_symbol_; + ar &graph_view_; + } }; /** @@ -362,7 +447,17 @@ class ScanAllByLabel : public ScanAll { storage::Label label() const { return label_; } private: - const storage::Label label_; + storage::Label label_; + + ScanAllByLabel() {} + + friend class boost::serialization::access; + + template + void serialize(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object(*this); + ar &label_; + } }; /** @@ -408,10 +503,57 @@ class ScanAllByLabelPropertyRange : public ScanAll { auto upper_bound() const { return upper_bound_; } private: - const storage::Label label_; - const storage::Property property_; + storage::Label label_; + storage::Property property_; std::experimental::optional lower_bound_; std::experimental::optional upper_bound_; + + ScanAllByLabelPropertyRange() {} + + friend class boost::serialization::access; + + BOOST_SERIALIZATION_SPLIT_MEMBER(); + + template + void save(TArchive &ar, const unsigned int) const { + ar &boost::serialization::base_object(*this); + ar &label_; + ar &property_; + auto save_bound = [&ar](auto &maybe_bound) { + if (!maybe_bound) { + ar & false; + return; + } + ar & true; + auto &bound = *maybe_bound; + ar &bound.type(); + SavePointer(ar, bound.value()); + }; + save_bound(lower_bound_); + save_bound(upper_bound_); + } + + template + void load(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object(*this); + ar &label_; + ar &property_; + auto load_bound = [&ar](auto &maybe_bound) { + bool has_bound = false; + ar &has_bound; + if (!has_bound) { + maybe_bound = std::experimental::nullopt; + return; + } + utils::BoundType type; + ar &type; + Expression *value; + LoadPointer(ar, value); + maybe_bound = std::experimental::make_optional(Bound(value, type)); + }; + load_bound(lower_bound_); + load_bound(upper_bound_); + } }; /** @@ -449,9 +591,31 @@ class ScanAllByLabelPropertyValue : public ScanAll { auto expression() const { return expression_; } private: - const storage::Label label_; - const storage::Property property_; + storage::Label label_; + storage::Property property_; Expression *expression_; + + ScanAllByLabelPropertyValue() {} + + friend class boost::serialization::access; + + BOOST_SERIALIZATION_SPLIT_MEMBER(); + + template + void save(TArchive &ar, const unsigned int) const { + ar &boost::serialization::base_object(*this); + ar &label_; + ar &property_; + SavePointer(ar, expression_); + } + + template + void load(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object(*this); + ar &label_; + ar &property_; + LoadPointer(ar, expression_); + } }; /** @@ -511,22 +675,22 @@ class ExpandCommon { protected: // info on what's getting expanded - const Symbol node_symbol_; - const Symbol edge_symbol_; - const EdgeAtom::Direction direction_; - const std::vector edge_types_; + Symbol node_symbol_; + Symbol edge_symbol_; + EdgeAtom::Direction direction_; + std::vector edge_types_; // the input op and the symbol under which the op's result // can be found in the frame - const std::shared_ptr input_; - const Symbol input_symbol_; + std::shared_ptr input_; + Symbol input_symbol_; // If the given node atom refer to a symbol that has already been expanded and // should be just validated in the frame. - const bool existing_node_; + bool existing_node_; // from which state the input node should get expanded - const GraphView graph_view_; + GraphView graph_view_; /** * For a newly expanded node handles existence checking and @@ -537,6 +701,23 @@ class ExpandCommon { * the old. */ bool HandleExistingNode(const VertexAccessor &new_node, Frame &frame) const; + + ExpandCommon() {} + + private: + friend class boost::serialization::access; + + template + void serialize(TArchive &ar, const unsigned int) { + ar &node_symbol_; + ar &edge_symbol_; + ar &direction_; + ar &edge_types_; + ar &input_; + ar &input_symbol_; + ar &existing_node_; + ar &graph_view_; + } }; /** @@ -586,6 +767,15 @@ class Expand : public LogicalOperator, public ExpandCommon { bool InitEdges(Frame &, Context &); }; + + private: + friend class boost::serialization::access; + + template + void serialize(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object(*this); + ar &boost::serialization::base_object(*this); + } }; /** @@ -655,20 +845,52 @@ class ExpandVariable : public LogicalOperator, public ExpandCommon { auto type() const { return type_; } private: - const EdgeAtom::Type type_; + EdgeAtom::Type type_; // True if the path should be written as expanding from node_symbol to // input_symbol. - const bool is_reverse_; + bool is_reverse_; // lower and upper bounds of the variable length expansion // both are optional, defaults are (1, inf) Expression *lower_bound_; Expression *upper_bound_; // symbols for a single node and edge that are currently getting expanded - const Symbol inner_edge_symbol_; - const Symbol inner_node_symbol_; + Symbol inner_edge_symbol_; + Symbol inner_node_symbol_; // a filtering expression for skipping expansions during expansion // can refer to inner node and edges Expression *filter_; + + ExpandVariable() {} + + friend class boost::serialization::access; + + BOOST_SERIALIZATION_SPLIT_MEMBER(); + + template + void save(TArchive &ar, const unsigned int) const { + ar &boost::serialization::base_object(*this); + ar &boost::serialization::base_object(*this); + ar &type_; + ar &is_reverse_; + SavePointer(ar, lower_bound_); + SavePointer(ar, upper_bound_); + ar &inner_edge_symbol_; + ar &inner_node_symbol_; + SavePointer(ar, filter_); + } + + template + void load(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object(*this); + ar &boost::serialization::base_object(*this); + ar &type_; + ar &is_reverse_; + LoadPointer(ar, lower_bound_); + LoadPointer(ar, upper_bound_); + ar &inner_edge_symbol_; + ar &inner_node_symbol_; + LoadPointer(ar, filter_); + } }; /** @@ -691,9 +913,21 @@ class ConstructNamedPath : public LogicalOperator { const auto &path_elements() const { return path_elements_; } private: - const std::shared_ptr input_; - const Symbol path_symbol_; - const std::vector path_elements_; + std::shared_ptr input_; + Symbol path_symbol_; + std::vector path_elements_; + + ConstructNamedPath() {} + + friend class boost::serialization::access; + + template + void serialize(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object(*this); + ar &input_; + ar &path_symbol_; + ar &path_elements_; + } }; /** @@ -713,9 +947,11 @@ class Filter : public LogicalOperator { database::GraphDbAccessor &db) const override; private: - const std::shared_ptr input_; + std::shared_ptr input_; Expression *expression_; + Filter() {} + class FilterCursor : public Cursor { public: FilterCursor(const Filter &self, database::GraphDbAccessor &db); @@ -727,6 +963,24 @@ class Filter : public LogicalOperator { database::GraphDbAccessor &db_; const std::unique_ptr input_cursor_; }; + + friend class boost::serialization::access; + + BOOST_SERIALIZATION_SPLIT_MEMBER(); + + template + void save(TArchive &ar, const unsigned int) const { + ar &boost::serialization::base_object(*this); + ar &input_; + SavePointer(ar, expression_); + } + + template + void load(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object(*this); + ar &input_; + LoadPointer(ar, expression_); + } }; /** @@ -751,8 +1005,8 @@ class Produce : public LogicalOperator { const std::vector &named_expressions(); private: - const std::shared_ptr input_; - const std::vector named_expressions_; + std::shared_ptr input_; + std::vector named_expressions_; class ProduceCursor : public Cursor { public: @@ -765,6 +1019,16 @@ class Produce : public LogicalOperator { database::GraphDbAccessor &db_; const std::unique_ptr input_cursor_; }; + + Produce() {} + friend class boost::serialization::access; + + template + void serialize(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object(*this); + ar &input_; + ar &named_expressions_; + } }; /** @@ -782,12 +1046,14 @@ class Delete : public LogicalOperator { database::GraphDbAccessor &db) const override; private: - const std::shared_ptr input_; - const std::vector expressions_; + std::shared_ptr input_; + std::vector expressions_; // if the vertex should be detached before deletion // if not detached, and has connections, an error is raised // ignored when deleting edges - const bool detach_; + bool detach_; + + Delete() {} class DeleteCursor : public Cursor { public: @@ -800,6 +1066,26 @@ class Delete : public LogicalOperator { database::GraphDbAccessor &db_; const std::unique_ptr input_cursor_; }; + + friend class boost::serialization::access; + + BOOST_SERIALIZATION_SPLIT_MEMBER(); + + template + void save(TArchive &ar, const unsigned int) const { + ar &boost::serialization::base_object(*this); + ar &input_; + SavePointers(ar, expressions_); + ar &detach_; + } + + template + void load(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object(*this); + ar &input_; + LoadPointers(ar, expressions_); + ar &detach_; + } }; /** @@ -817,10 +1103,12 @@ class SetProperty : public LogicalOperator { database::GraphDbAccessor &db) const override; private: - const std::shared_ptr input_; + std::shared_ptr input_; PropertyLookup *lhs_; Expression *rhs_; + SetProperty() {} + class SetPropertyCursor : public Cursor { public: SetPropertyCursor(const SetProperty &self, database::GraphDbAccessor &db); @@ -832,6 +1120,26 @@ class SetProperty : public LogicalOperator { database::GraphDbAccessor &db_; const std::unique_ptr input_cursor_; }; + + friend class boost::serialization::access; + + BOOST_SERIALIZATION_SPLIT_MEMBER(); + + template + void save(TArchive &ar, const unsigned int) const { + ar &boost::serialization::base_object(*this); + ar &input_; + SavePointer(ar, lhs_); + SavePointer(ar, rhs_); + } + + template + void load(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object(*this); + ar &input_; + LoadPointer(ar, lhs_); + LoadPointer(ar, rhs_); + } }; /** @@ -865,11 +1173,13 @@ class SetProperties : public LogicalOperator { database::GraphDbAccessor &db) const override; private: - const std::shared_ptr input_; - const Symbol input_symbol_; + std::shared_ptr input_; + Symbol input_symbol_; Expression *rhs_; Op op_; + SetProperties() {} + class SetPropertiesCursor : public Cursor { public: SetPropertiesCursor(const SetProperties &self, @@ -890,6 +1200,28 @@ class SetProperties : public LogicalOperator { template void Set(TRecordAccessor &record, const TypedValue &rhs) const; }; + + friend class boost::serialization::access; + + BOOST_SERIALIZATION_SPLIT_MEMBER(); + + template + void save(TArchive &ar, const unsigned int) const { + ar &boost::serialization::base_object(*this); + ar &input_; + ar &input_symbol_; + SavePointer(ar, rhs_); + ar &op_; + } + + template + void load(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object(*this); + ar &input_; + ar &input_symbol_; + LoadPointer(ar, rhs_); + ar &op_; + } }; /** @@ -907,9 +1239,11 @@ class SetLabels : public LogicalOperator { database::GraphDbAccessor &db) const override; private: - const std::shared_ptr input_; - const Symbol input_symbol_; - const std::vector labels_; + std::shared_ptr input_; + Symbol input_symbol_; + std::vector labels_; + + SetLabels() {} class SetLabelsCursor : public Cursor { public: @@ -921,6 +1255,16 @@ class SetLabels : public LogicalOperator { const SetLabels &self_; const std::unique_ptr input_cursor_; }; + + friend class boost::serialization::access; + + template + void serialize(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object(*this); + ar &input_; + ar &input_symbol_; + ar &labels_; + } }; /** @@ -936,9 +1280,11 @@ class RemoveProperty : public LogicalOperator { database::GraphDbAccessor &db) const override; private: - const std::shared_ptr input_; + std::shared_ptr input_; PropertyLookup *lhs_; + RemoveProperty() {} + class RemovePropertyCursor : public Cursor { public: RemovePropertyCursor(const RemoveProperty &self, @@ -951,6 +1297,24 @@ class RemoveProperty : public LogicalOperator { database::GraphDbAccessor &db_; const std::unique_ptr input_cursor_; }; + + friend class boost::serialization::access; + + BOOST_SERIALIZATION_SPLIT_MEMBER(); + + template + void save(TArchive &ar, const unsigned int) const { + ar &boost::serialization::base_object(*this); + ar &input_; + SavePointer(ar, lhs_); + } + + template + void load(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object(*this); + ar &input_; + LoadPointer(ar, lhs_); + } }; /** @@ -968,9 +1332,11 @@ class RemoveLabels : public LogicalOperator { database::GraphDbAccessor &db) const override; private: - const std::shared_ptr input_; - const Symbol input_symbol_; - const std::vector labels_; + std::shared_ptr input_; + Symbol input_symbol_; + std::vector labels_; + + RemoveLabels() {} class RemoveLabelsCursor : public Cursor { public: @@ -982,6 +1348,16 @@ class RemoveLabels : public LogicalOperator { const RemoveLabels &self_; const std::unique_ptr input_cursor_; }; + + friend class boost::serialization::access; + + template + void serialize(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object(*this); + ar &input_; + ar &input_symbol_; + ar &labels_; + } }; /** @@ -1018,9 +1394,11 @@ class ExpandUniquenessFilter : public LogicalOperator { database::GraphDbAccessor &db) const override; private: - const std::shared_ptr input_; + std::shared_ptr input_; Symbol expand_symbol_; - const std::vector previous_symbols_; + std::vector previous_symbols_; + + ExpandUniquenessFilter() {} class ExpandUniquenessFilterCursor : public Cursor { public: @@ -1033,6 +1411,16 @@ class ExpandUniquenessFilter : public LogicalOperator { const ExpandUniquenessFilter &self_; const std::unique_ptr input_cursor_; }; + + friend class boost::serialization::access; + + template + void serialize(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object(*this); + ar &input_; + ar &expand_symbol_; + ar &previous_symbols_; + } }; /** @brief Pulls everything from the input before passing it through. @@ -1073,9 +1461,11 @@ class Accumulate : public LogicalOperator { const auto &symbols() const { return symbols_; }; private: - const std::shared_ptr input_; - const std::vector symbols_; - const bool advance_command_; + std::shared_ptr input_; + std::vector symbols_; + bool advance_command_; + + Accumulate() {} class AccumulateCursor : public Cursor { public: @@ -1091,6 +1481,16 @@ class Accumulate : public LogicalOperator { decltype(cache_.begin()) cache_it_ = cache_.begin(); bool pulled_all_input_{false}; }; + + friend class boost::serialization::access; + + template + void serialize(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object(*this); + ar &input_; + ar &symbols_; + ar &advance_command_; + } }; /** @@ -1127,6 +1527,27 @@ class Aggregate : public LogicalOperator { Expression *key; Aggregation::Op op; Symbol output_sym; + + private: + friend class boost::serialization::access; + + BOOST_SERIALIZATION_SPLIT_MEMBER(); + + template + void save(TArchive &ar, const unsigned int) const { + SavePointer(ar, value); + SavePointer(ar, key); + ar &op; + ar &output_sym; + } + + template + void load(TArchive &ar, const unsigned int) { + LoadPointer(ar, value); + LoadPointer(ar, key); + ar &op; + ar &output_sym; + } }; Aggregate(const std::shared_ptr &input, @@ -1141,10 +1562,12 @@ class Aggregate : public LogicalOperator { const auto &group_by() const { return group_by_; } private: - const std::shared_ptr input_; - const std::vector aggregations_; - const std::vector group_by_; - const std::vector remember_; + std::shared_ptr input_; + std::vector aggregations_; + std::vector group_by_; + std::vector remember_; + + Aggregate() {} class AggregateCursor : public Cursor { public: @@ -1224,6 +1647,28 @@ class Aggregate : public LogicalOperator { * an appropriate exception is thrown. */ void EnsureOkForAvgSum(const TypedValue &value) const; }; + + friend class boost::serialization::access; + + BOOST_SERIALIZATION_SPLIT_MEMBER(); + + template + void save(TArchive &ar, const unsigned int) const { + ar &boost::serialization::base_object(*this); + ar &input_; + ar &aggregations_; + SavePointers(ar, group_by_); + ar &remember_; + } + + template + void load(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object(*this); + ar &input_; + ar &aggregations_; + LoadPointers(ar, group_by_); + ar &remember_; + } }; /** @brief Skips a number of Pulls from the input op. @@ -1247,9 +1692,11 @@ class Skip : public LogicalOperator { std::vector OutputSymbols(const SymbolTable &) const override; private: - const std::shared_ptr input_; + std::shared_ptr input_; Expression *expression_; + Skip() {} + class SkipCursor : public Cursor { public: SkipCursor(const Skip &self, database::GraphDbAccessor &db); @@ -1265,6 +1712,24 @@ class Skip : public LogicalOperator { int to_skip_{-1}; int skipped_{0}; }; + + friend class boost::serialization::access; + + BOOST_SERIALIZATION_SPLIT_MEMBER(); + + template + void save(TArchive &ar, const unsigned int) const { + ar &boost::serialization::base_object(*this); + ar &input_; + SavePointer(ar, expression_); + } + + template + void load(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object(*this); + ar &input_; + LoadPointer(ar, expression_); + } }; /** @brief Limits the number of Pulls from the input op. @@ -1291,9 +1756,11 @@ class Limit : public LogicalOperator { std::vector OutputSymbols(const SymbolTable &) const override; private: - const std::shared_ptr input_; + std::shared_ptr input_; Expression *expression_; + Limit() {} + class LimitCursor : public Cursor { public: LimitCursor(const Limit &self, database::GraphDbAccessor &db); @@ -1309,6 +1776,24 @@ class Limit : public LogicalOperator { int limit_{-1}; int pulled_{0}; }; + + friend class boost::serialization::access; + + BOOST_SERIALIZATION_SPLIT_MEMBER(); + + template + void save(TArchive &ar, const unsigned int) const { + ar &boost::serialization::base_object(*this); + ar &input_; + SavePointer(ar, expression_); + } + + template + void load(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object(*this); + ar &input_; + LoadPointer(ar, expression_); + } }; /** @brief Logical operator for ordering (sorting) results. @@ -1349,12 +1834,21 @@ class OrderBy : public LogicalOperator { private: std::vector ordering_; + + friend class boost::serialization::access; + + template + void serialize(TArchive &ar, const unsigned int) { + ar &ordering_; + } }; - const std::shared_ptr input_; + std::shared_ptr input_; TypedValueVectorCompare compare_; std::vector order_by_; - const std::vector output_symbols_; + std::vector output_symbols_; + + OrderBy() {} // custom comparison for TypedValue objects // behaves generally like Neo's ORDER BY comparison operator: @@ -1383,6 +1877,28 @@ class OrderBy : public LogicalOperator { // iterator over the cache_, maintains state between Pulls decltype(cache_.begin()) cache_it_ = cache_.begin(); }; + + friend class boost::serialization::access; + + BOOST_SERIALIZATION_SPLIT_MEMBER(); + + template + void save(TArchive &ar, const unsigned int) const { + ar &boost::serialization::base_object(*this); + ar &input_; + ar &compare_; + SavePointers(ar, order_by_); + ar &output_symbols_; + } + + template + void load(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object(*this); + ar &input_; + ar &compare_; + LoadPointers(ar, order_by_); + ar &output_symbols_; + } }; /** @@ -1412,9 +1928,11 @@ class Merge : public LogicalOperator { auto merge_create() const { return merge_create_; } private: - const std::shared_ptr input_; - const std::shared_ptr merge_match_; - const std::shared_ptr merge_create_; + std::shared_ptr input_; + std::shared_ptr merge_match_; + std::shared_ptr merge_create_; + + Merge() {} class MergeCursor : public Cursor { public: @@ -1434,6 +1952,16 @@ class Merge : public LogicalOperator { // - previous Pull from this cursor exhausted the merge_match_cursor bool pull_input_{true}; }; + + friend class boost::serialization::access; + + template + void serialize(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object(*this); + ar &input_; + ar &merge_match_; + ar &merge_create_; + } }; /** @@ -1458,9 +1986,11 @@ class Optional : public LogicalOperator { const auto &optional_symbols() const { return optional_symbols_; } private: - const std::shared_ptr input_; - const std::shared_ptr optional_; - const std::vector optional_symbols_; + std::shared_ptr input_; + std::shared_ptr optional_; + std::vector optional_symbols_; + + Optional() {} class OptionalCursor : public Cursor { public: @@ -1479,6 +2009,16 @@ class Optional : public LogicalOperator { // - previous Pull from this cursor exhausted the optional_cursor_ bool pull_input_{true}; }; + + friend class boost::serialization::access; + + template + void serialize(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object(*this); + ar &input_; + ar &optional_; + ar &optional_symbols_; + } }; /** @@ -1498,9 +2038,11 @@ class Unwind : public LogicalOperator { Expression *input_expression() const { return input_expression_; } private: - const std::shared_ptr input_; + std::shared_ptr input_; Expression *input_expression_; - const Symbol output_symbol_; + Symbol output_symbol_; + + Unwind() {} class UnwindCursor : public Cursor { public: @@ -1517,6 +2059,26 @@ class Unwind : public LogicalOperator { // current position in input_value_ std::vector::iterator input_value_it_ = input_value_.end(); }; + + friend class boost::serialization::access; + + BOOST_SERIALIZATION_SPLIT_MEMBER(); + + template + void save(TArchive &ar, const unsigned int) const { + ar &boost::serialization::base_object(*this); + ar &input_; + SavePointer(ar, input_expression_); + ar &output_symbol_; + } + + template + void load(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object(*this); + ar &input_; + LoadPointer(ar, input_expression_); + ar &output_symbol_; + } }; /** @@ -1537,8 +2099,10 @@ class Distinct : public LogicalOperator { std::vector OutputSymbols(const SymbolTable &) const override; private: - const std::shared_ptr input_; - const std::vector value_symbols_; + std::shared_ptr input_; + std::vector value_symbols_; + + Distinct() {} class DistinctCursor : public Cursor { public: @@ -1558,6 +2122,15 @@ class Distinct : public LogicalOperator { TypedValueVectorEqual> seen_rows_; }; + + friend class boost::serialization::access; + + template + void serialize(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object(*this); + ar &input_; + ar &value_symbols_; + } }; /** @@ -1581,6 +2154,17 @@ class CreateIndex : public LogicalOperator { private: storage::Label label_; storage::Property property_; + + CreateIndex() {} + + friend class boost::serialization::access; + + template + void serialize(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object(*this); + ar &label_; + ar &property_; + } }; /** @@ -1603,8 +2187,10 @@ class Union : public LogicalOperator { std::vector OutputSymbols(const SymbolTable &) const override; private: - const std::shared_ptr left_op_, right_op_; - const std::vector union_symbols_, left_symbols_, right_symbols_; + std::shared_ptr left_op_, right_op_; + std::vector union_symbols_, left_symbols_, right_symbols_; + + Union() {} class UnionCursor : public Cursor { public: @@ -1616,6 +2202,18 @@ class Union : public LogicalOperator { const Union &self_; const std::unique_ptr left_cursor_, right_cursor_; }; + + friend class boost::serialization::access; + + template + void serialize(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object(*this); + ar &left_op_; + ar &right_op_; + ar &union_symbols_; + ar &left_symbols_; + ar &right_symbols_; + } }; } // namespace plan diff --git a/tests/unit/query_planner.cpp b/tests/unit/query_planner.cpp index e78299861..4f1ba1527 100644 --- a/tests/unit/query_planner.cpp +++ b/tests/unit/query_planner.cpp @@ -1,7 +1,11 @@ #include +#include #include +#include #include +#include "boost/archive/binary_iarchive.hpp" +#include "boost/archive/binary_oarchive.hpp" #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -182,15 +186,23 @@ class ExpectAggregate : public OpChecker { for (const auto &aggr_elem : op.aggregations()) { ASSERT_NE(aggr_it, aggregations_.end()); auto aggr = *aggr_it++; - EXPECT_EQ(aggr_elem.value, aggr->expression1_); - EXPECT_EQ(aggr_elem.key, aggr->expression2_); + // TODO: Proper expression equality + EXPECT_EQ(typeid(aggr_elem.value).hash_code(), + typeid(aggr->expression1_).hash_code()); + EXPECT_EQ(typeid(aggr_elem.key).hash_code(), + typeid(aggr->expression2_).hash_code()); EXPECT_EQ(aggr_elem.op, aggr->op_); EXPECT_EQ(aggr_elem.output_sym, symbol_table.at(*aggr)); } EXPECT_EQ(aggr_it, aggregations_.end()); - auto got_group_by = std::unordered_set( - op.group_by().begin(), op.group_by().end()); - EXPECT_EQ(group_by_, got_group_by); + // TODO: Proper group by expression equality + std::unordered_set got_group_by; + for (auto *expr : op.group_by()) + got_group_by.insert(typeid(*expr).hash_code()); + std::unordered_set expected_group_by; + for (auto *expr : group_by_) + expected_group_by.insert(typeid(*expr).hash_code()); + EXPECT_EQ(got_group_by, expected_group_by); } private: @@ -252,7 +264,9 @@ class ExpectScanAllByLabelPropertyValue const SymbolTable &) override { EXPECT_EQ(scan_all.label(), label_); EXPECT_EQ(scan_all.property(), property_); - EXPECT_EQ(scan_all.expression(), expression_); + // TODO: Proper expression equality + EXPECT_EQ(typeid(scan_all.expression()).hash_code(), + typeid(expression_).hash_code()); } private: @@ -279,12 +293,16 @@ class ExpectScanAllByLabelPropertyRange EXPECT_EQ(scan_all.property(), property_); if (lower_bound_) { ASSERT_TRUE(scan_all.lower_bound()); - EXPECT_EQ(scan_all.lower_bound()->value(), lower_bound_->value()); + // TODO: Proper expression equality + EXPECT_EQ(typeid(scan_all.lower_bound()->value()).hash_code(), + typeid(lower_bound_->value()).hash_code()); EXPECT_EQ(scan_all.lower_bound()->type(), lower_bound_->type()); } if (upper_bound_) { ASSERT_TRUE(scan_all.upper_bound()); - EXPECT_EQ(scan_all.upper_bound()->value(), upper_bound_->value()); + // TODO: Proper expression equality + EXPECT_EQ(typeid(scan_all.upper_bound()->value()).hash_code(), + typeid(upper_bound_->value()).hash_code()); EXPECT_EQ(scan_all.upper_bound()->type(), upper_bound_->type()); } } @@ -318,6 +336,44 @@ auto MakeSymbolTable(query::Query &query) { return symbol_table; } +class Planner { + public: + Planner(std::vector single_query_parts, + PlanningContext &context) { + plan_ = MakeLogicalPlanForSingleQuery(single_query_parts, + context); + } + + auto &plan() { return *plan_; } + + private: + std::unique_ptr plan_; +}; + +class SerializedPlanner { + public: + SerializedPlanner(std::vector single_query_parts, + PlanningContext &context) { + std::stringstream stream; + { + auto original_plan = MakeLogicalPlanForSingleQuery( + single_query_parts, context); + boost::archive::binary_oarchive out_archive(stream); + out_archive << original_plan; + } + { + boost::archive::binary_iarchive in_archive(stream); + std::tie(plan_, ast_storage_) = LoadPlan(in_archive); + } + } + + auto &plan() { return *plan_; } + + private: + AstTreeStorage ast_storage_; + std::unique_ptr plan_; +}; + template auto CheckPlan(LogicalOperator &plan, const SymbolTable &symbol_table, TChecker... checker) { @@ -327,7 +383,7 @@ auto CheckPlan(LogicalOperator &plan, const SymbolTable &symbol_table, EXPECT_TRUE(plan_checker.checkers_.empty()); } -template +template auto CheckPlan(AstTreeStorage &storage, TChecker... checker) { auto symbol_table = MakeSymbolTable(*storage.query()); database::SingleNode db; @@ -336,19 +392,25 @@ auto CheckPlan(AstTreeStorage &storage, TChecker... checker) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery( - single_query_parts, planning_context); - CheckPlan(*plan, symbol_table, checker...); + TPlanner planner(single_query_parts, planning_context); + CheckPlan(planner.plan(), symbol_table, checker...); } -TEST(TestLogicalPlanner, MatchNodeReturn) { +template +class TestPlanner : public ::testing::Test {}; + +using PlannerTypes = ::testing::Types; + +TYPED_TEST_CASE(TestPlanner, PlannerTypes); + +TYPED_TEST(TestPlanner, MatchNodeReturn) { // Test MATCH (n) RETURN n AstTreeStorage storage; QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), RETURN("n"))); - CheckPlan(storage, ExpectScanAll(), ExpectProduce()); + CheckPlan(storage, ExpectScanAll(), ExpectProduce()); } -TEST(TestLogicalPlanner, CreateNodeReturn) { +TYPED_TEST(TestPlanner, CreateNodeReturn) { // Test CREATE (n) RETURN n AS n AstTreeStorage storage; auto ident_n = IDENT("n"); @@ -362,12 +424,12 @@ TEST(TestLogicalPlanner, CreateNodeReturn) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery( - single_query_parts, planning_context); - CheckPlan(*plan, symbol_table, ExpectCreateNode(), acc, ExpectProduce()); + TypeParam planner(single_query_parts, planning_context); + CheckPlan(planner.plan(), symbol_table, ExpectCreateNode(), acc, + ExpectProduce()); } -TEST(TestLogicalPlanner, CreateExpand) { +TYPED_TEST(TestPlanner, CreateExpand) { // Test CREATE (n) -[r :rel1]-> (m) AstTreeStorage storage; database::SingleNode db; @@ -375,17 +437,17 @@ TEST(TestLogicalPlanner, CreateExpand) { auto relationship = dba.EdgeType("relationship"); QUERY(SINGLE_QUERY(CREATE(PATTERN( NODE("n"), EDGE("r", Direction::OUT, {relationship}), NODE("m"))))); - CheckPlan(storage, ExpectCreateNode(), ExpectCreateExpand()); + CheckPlan(storage, ExpectCreateNode(), ExpectCreateExpand()); } -TEST(TestLogicalPlanner, CreateMultipleNode) { +TYPED_TEST(TestPlanner, CreateMultipleNode) { // Test CREATE (n), (m) AstTreeStorage storage; QUERY(SINGLE_QUERY(CREATE(PATTERN(NODE("n")), PATTERN(NODE("m"))))); - CheckPlan(storage, ExpectCreateNode(), ExpectCreateNode()); + CheckPlan(storage, ExpectCreateNode(), ExpectCreateNode()); } -TEST(TestLogicalPlanner, CreateNodeExpandNode) { +TYPED_TEST(TestPlanner, CreateNodeExpandNode) { // Test CREATE (n) -[r :rel]-> (m), (l) AstTreeStorage storage; database::SingleNode db; @@ -394,11 +456,11 @@ TEST(TestLogicalPlanner, CreateNodeExpandNode) { QUERY(SINGLE_QUERY(CREATE( PATTERN(NODE("n"), EDGE("r", Direction::OUT, {relationship}), NODE("m")), PATTERN(NODE("l"))))); - CheckPlan(storage, ExpectCreateNode(), ExpectCreateExpand(), - ExpectCreateNode()); + CheckPlan(storage, ExpectCreateNode(), ExpectCreateExpand(), + ExpectCreateNode()); } -TEST(TestLogicalPlanner, CreateNamedPattern) { +TYPED_TEST(TestPlanner, CreateNamedPattern) { // Test CREATE p = (n) -[r :rel]-> (m) AstTreeStorage storage; database::SingleNode db; @@ -406,11 +468,11 @@ TEST(TestLogicalPlanner, CreateNamedPattern) { auto relationship = dba.EdgeType("rel"); QUERY(SINGLE_QUERY(CREATE(NAMED_PATTERN( "p", NODE("n"), EDGE("r", Direction::OUT, {relationship}), NODE("m"))))); - CheckPlan(storage, ExpectCreateNode(), ExpectCreateExpand(), - ExpectConstructNamedPath()); + CheckPlan(storage, ExpectCreateNode(), ExpectCreateExpand(), + ExpectConstructNamedPath()); } -TEST(TestLogicalPlanner, MatchCreateExpand) { +TYPED_TEST(TestPlanner, MatchCreateExpand) { // Test MATCH (n) CREATE (n) -[r :rel1]-> (m) AstTreeStorage storage; database::SingleNode db; @@ -420,20 +482,20 @@ TEST(TestLogicalPlanner, MatchCreateExpand) { MATCH(PATTERN(NODE("n"))), CREATE(PATTERN(NODE("n"), EDGE("r", Direction::OUT, {relationship}), NODE("m"))))); - CheckPlan(storage, ExpectScanAll(), ExpectCreateExpand()); + CheckPlan(storage, ExpectScanAll(), ExpectCreateExpand()); } -TEST(TestLogicalPlanner, MatchLabeledNodes) { +TYPED_TEST(TestPlanner, MatchLabeledNodes) { // Test MATCH (n :label) RETURN n AstTreeStorage storage; database::SingleNode db; database::GraphDbAccessor dba(db); auto label = dba.Label("label"); QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n", label))), RETURN("n"))); - CheckPlan(storage, ExpectScanAllByLabel(), ExpectProduce()); + CheckPlan(storage, ExpectScanAllByLabel(), ExpectProduce()); } -TEST(TestLogicalPlanner, MatchPathReturn) { +TYPED_TEST(TestPlanner, MatchPathReturn) { // Test MATCH (n) -[r :relationship]- (m) RETURN n AstTreeStorage storage; database::SingleNode db; @@ -443,10 +505,11 @@ TEST(TestLogicalPlanner, MatchPathReturn) { MATCH(PATTERN(NODE("n"), EDGE("r", Direction::BOTH, {relationship}), NODE("m"))), RETURN("n"))); - CheckPlan(storage, ExpectScanAll(), ExpectExpand(), ExpectProduce()); + CheckPlan(storage, ExpectScanAll(), ExpectExpand(), + ExpectProduce()); } -TEST(TestLogicalPlanner, MatchNamedPatternReturn) { +TYPED_TEST(TestPlanner, MatchNamedPatternReturn) { // Test MATCH p = (n) -[r :relationship]- (m) RETURN p AstTreeStorage storage; database::SingleNode db; @@ -457,11 +520,11 @@ TEST(TestLogicalPlanner, MatchNamedPatternReturn) { EDGE("r", Direction::BOTH, {relationship}), NODE("m"))), RETURN("n"))); - CheckPlan(storage, ExpectScanAll(), ExpectExpand(), - ExpectConstructNamedPath(), ExpectProduce()); + CheckPlan(storage, ExpectScanAll(), ExpectExpand(), + ExpectConstructNamedPath(), ExpectProduce()); } -TEST(TestLogicalPlanner, MatchNamedPatternWithPredicateReturn) { +TYPED_TEST(TestPlanner, MatchNamedPatternWithPredicateReturn) { // Test MATCH p = (n) -[r :relationship]- (m) RETURN p AstTreeStorage storage; database::SingleNode db; @@ -472,11 +535,12 @@ TEST(TestLogicalPlanner, MatchNamedPatternWithPredicateReturn) { EDGE("r", Direction::BOTH, {relationship}), NODE("m"))), WHERE(EQ(LITERAL(2), IDENT("p"))), RETURN("n"))); - CheckPlan(storage, ExpectScanAll(), ExpectExpand(), - ExpectConstructNamedPath(), ExpectFilter(), ExpectProduce()); + CheckPlan(storage, ExpectScanAll(), ExpectExpand(), + ExpectConstructNamedPath(), ExpectFilter(), + ExpectProduce()); } -TEST(TestLogicalPlanner, OptionalMatchNamedPatternReturn) { +TYPED_TEST(TestPlanner, OptionalMatchNamedPatternReturn) { // Test OPTIONAL MATCH p = (n) -[r]- (m) RETURN p database::SingleNode db; database::GraphDbAccessor dba(db); @@ -491,9 +555,6 @@ TEST(TestLogicalPlanner, OptionalMatchNamedPatternReturn) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery( - single_query_parts, planning_context); - std::list optional{new ExpectScanAll(), new ExpectExpand(), new ExpectConstructNamedPath()}; auto get_symbol = [&symbol_table](const auto *ast_node) { @@ -501,11 +562,12 @@ TEST(TestLogicalPlanner, OptionalMatchNamedPatternReturn) { }; std::vector optional_symbols{get_symbol(pattern), get_symbol(node_n), get_symbol(edge), get_symbol(node_m)}; - CheckPlan(*plan, symbol_table, ExpectOptional(optional_symbols, optional), - ExpectProduce()); + TypeParam planner(single_query_parts, planning_context); + CheckPlan(planner.plan(), symbol_table, + ExpectOptional(optional_symbols, optional), ExpectProduce()); } -TEST(TestLogicalPlanner, MatchWhereReturn) { +TYPED_TEST(TestPlanner, MatchWhereReturn) { // Test MATCH (n) WHERE n.property < 42 RETURN n AstTreeStorage storage; database::SingleNode db; @@ -514,17 +576,18 @@ TEST(TestLogicalPlanner, MatchWhereReturn) { QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), WHERE(LESS(PROPERTY_LOOKUP("n", property), LITERAL(42))), RETURN("n"))); - CheckPlan(storage, ExpectScanAll(), ExpectFilter(), ExpectProduce()); + CheckPlan(storage, ExpectScanAll(), ExpectFilter(), + ExpectProduce()); } -TEST(TestLogicalPlanner, MatchDelete) { +TYPED_TEST(TestPlanner, MatchDelete) { // Test MATCH (n) DELETE n AstTreeStorage storage; QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), DELETE(IDENT("n")))); - CheckPlan(storage, ExpectScanAll(), ExpectDelete()); + CheckPlan(storage, ExpectScanAll(), ExpectDelete()); } -TEST(TestLogicalPlanner, MatchNodeSet) { +TYPED_TEST(TestPlanner, MatchNodeSet) { // Test MATCH (n) SET n.prop = 42, n = n, n :label AstTreeStorage storage; database::SingleNode db; @@ -534,11 +597,11 @@ TEST(TestLogicalPlanner, MatchNodeSet) { QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), SET(PROPERTY_LOOKUP("n", prop), LITERAL(42)), SET("n", IDENT("n")), SET("n", {label}))); - CheckPlan(storage, ExpectScanAll(), ExpectSetProperty(), - ExpectSetProperties(), ExpectSetLabels()); + CheckPlan(storage, ExpectScanAll(), ExpectSetProperty(), + ExpectSetProperties(), ExpectSetLabels()); } -TEST(TestLogicalPlanner, MatchRemove) { +TYPED_TEST(TestPlanner, MatchRemove) { // Test MATCH (n) REMOVE n.prop REMOVE n :label AstTreeStorage storage; database::SingleNode db; @@ -547,11 +610,11 @@ TEST(TestLogicalPlanner, MatchRemove) { auto label = dba.Label("label"); QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), REMOVE(PROPERTY_LOOKUP("n", prop)), REMOVE("n", {label}))); - CheckPlan(storage, ExpectScanAll(), ExpectRemoveProperty(), - ExpectRemoveLabels()); + CheckPlan(storage, ExpectScanAll(), ExpectRemoveProperty(), + ExpectRemoveLabels()); } -TEST(TestLogicalPlanner, MatchMultiPattern) { +TYPED_TEST(TestPlanner, MatchMultiPattern) { // Test MATCH (n) -[r]- (m), (j) -[e]- (i) RETURN n AstTreeStorage storage; QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m")), @@ -559,12 +622,12 @@ TEST(TestLogicalPlanner, MatchMultiPattern) { RETURN("n"))); // We expect the expansions after the first to have a uniqueness filter in a // single MATCH clause. - CheckPlan(storage, ExpectScanAll(), ExpectExpand(), ExpectScanAll(), - ExpectExpand(), ExpectExpandUniquenessFilter(), - ExpectProduce()); + CheckPlan( + storage, ExpectScanAll(), ExpectExpand(), ExpectScanAll(), ExpectExpand(), + ExpectExpandUniquenessFilter(), ExpectProduce()); } -TEST(TestLogicalPlanner, MatchMultiPatternSameStart) { +TYPED_TEST(TestPlanner, MatchMultiPatternSameStart) { // Test MATCH (n), (n) -[e]- (m) RETURN n AstTreeStorage storage; QUERY(SINGLE_QUERY( @@ -572,10 +635,11 @@ TEST(TestLogicalPlanner, MatchMultiPatternSameStart) { RETURN("n"))); // We expect the second pattern to generate only an Expand, since another // ScanAll would be redundant. - CheckPlan(storage, ExpectScanAll(), ExpectExpand(), ExpectProduce()); + CheckPlan(storage, ExpectScanAll(), ExpectExpand(), + ExpectProduce()); } -TEST(TestLogicalPlanner, MatchMultiPatternSameExpandStart) { +TYPED_TEST(TestPlanner, MatchMultiPatternSameExpandStart) { // Test MATCH (n) -[r]- (m), (m) -[e]- (l) RETURN n AstTreeStorage storage; QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m")), @@ -584,11 +648,12 @@ TEST(TestLogicalPlanner, MatchMultiPatternSameExpandStart) { // We expect the second pattern to generate only an Expand. Another // ScanAll would be redundant, as it would generate the nodes obtained from // expansion. Additionally, a uniqueness filter is expected. - CheckPlan(storage, ExpectScanAll(), ExpectExpand(), ExpectExpand(), - ExpectExpandUniquenessFilter(), ExpectProduce()); + CheckPlan(storage, ExpectScanAll(), ExpectExpand(), ExpectExpand(), + ExpectExpandUniquenessFilter(), + ExpectProduce()); } -TEST(TestLogicalPlanner, MultiMatch) { +TYPED_TEST(TestPlanner, MultiMatch) { // Test MATCH (n) -[r]- (m) MATCH (j) -[e]- (i) -[f]- (h) RETURN n AstTreeStorage storage; QUERY(SINGLE_QUERY( @@ -597,12 +662,13 @@ TEST(TestLogicalPlanner, MultiMatch) { RETURN("n"))); // Multiple MATCH clauses form a Cartesian product, so the uniqueness should // not cross MATCH boundaries. - CheckPlan(storage, ExpectScanAll(), ExpectExpand(), ExpectScanAll(), - ExpectExpand(), ExpectExpand(), - ExpectExpandUniquenessFilter(), ExpectProduce()); + CheckPlan(storage, ExpectScanAll(), ExpectExpand(), + ExpectScanAll(), ExpectExpand(), ExpectExpand(), + ExpectExpandUniquenessFilter(), + ExpectProduce()); } -TEST(TestLogicalPlanner, MultiMatchSameStart) { +TYPED_TEST(TestPlanner, MultiMatchSameStart) { // Test MATCH (n) MATCH (n) -[r]- (m) RETURN n AstTreeStorage storage; QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), @@ -610,19 +676,21 @@ TEST(TestLogicalPlanner, MultiMatchSameStart) { RETURN("n"))); // Similar to MatchMultiPatternSameStart, we expect only Expand from second // MATCH clause. - CheckPlan(storage, ExpectScanAll(), ExpectExpand(), ExpectProduce()); + CheckPlan(storage, ExpectScanAll(), ExpectExpand(), + ExpectProduce()); } -TEST(TestLogicalPlanner, MatchWithReturn) { +TYPED_TEST(TestPlanner, MatchWithReturn) { // Test MATCH (old) WITH old AS new RETURN new AstTreeStorage storage; QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("old"))), WITH("old", AS("new")), RETURN("new"))); // No accumulation since we only do reads. - CheckPlan(storage, ExpectScanAll(), ExpectProduce(), ExpectProduce()); + CheckPlan(storage, ExpectScanAll(), ExpectProduce(), + ExpectProduce()); } -TEST(TestLogicalPlanner, MatchWithWhereReturn) { +TYPED_TEST(TestPlanner, MatchWithWhereReturn) { // Test MATCH (old) WITH old AS new WHERE new.prop < 42 RETURN new database::SingleNode db; database::GraphDbAccessor dba(db); @@ -632,11 +700,11 @@ TEST(TestLogicalPlanner, MatchWithWhereReturn) { WHERE(LESS(PROPERTY_LOOKUP("new", prop), LITERAL(42))), RETURN("new"))); // No accumulation since we only do reads. - CheckPlan(storage, ExpectScanAll(), ExpectProduce(), ExpectFilter(), - ExpectProduce()); + CheckPlan(storage, ExpectScanAll(), ExpectProduce(), + ExpectFilter(), ExpectProduce()); } -TEST(TestLogicalPlanner, CreateMultiExpand) { +TYPED_TEST(TestPlanner, CreateMultiExpand) { // Test CREATE (n) -[r :r]-> (m), (n) - [p :p]-> (l) database::SingleNode db; database::GraphDbAccessor dba(db); @@ -646,11 +714,11 @@ TEST(TestLogicalPlanner, CreateMultiExpand) { QUERY(SINGLE_QUERY( CREATE(PATTERN(NODE("n"), EDGE("r", Direction::OUT, {r}), NODE("m")), PATTERN(NODE("n"), EDGE("p", Direction::OUT, {p}), NODE("l"))))); - CheckPlan(storage, ExpectCreateNode(), ExpectCreateExpand(), - ExpectCreateExpand()); + CheckPlan(storage, ExpectCreateNode(), ExpectCreateExpand(), + ExpectCreateExpand()); } -TEST(TestLogicalPlanner, MatchWithSumWhereReturn) { +TYPED_TEST(TestPlanner, MatchWithSumWhereReturn) { // Test MATCH (n) WITH SUM(n.prop) + 42 AS sum WHERE sum < 42 // RETURN sum AS result database::SingleNode db; @@ -663,11 +731,11 @@ TEST(TestLogicalPlanner, MatchWithSumWhereReturn) { MATCH(PATTERN(NODE("n"))), WITH(ADD(sum, literal), AS("sum")), WHERE(LESS(IDENT("sum"), LITERAL(42))), RETURN("sum", AS("result")))); auto aggr = ExpectAggregate({sum}, {literal}); - CheckPlan(storage, ExpectScanAll(), aggr, ExpectProduce(), ExpectFilter(), - ExpectProduce()); + CheckPlan(storage, ExpectScanAll(), aggr, ExpectProduce(), + ExpectFilter(), ExpectProduce()); } -TEST(TestLogicalPlanner, MatchReturnSum) { +TYPED_TEST(TestPlanner, MatchReturnSum) { // Test MATCH (n) RETURN SUM(n.prop1) AS sum, n.prop2 AS group database::SingleNode db; database::GraphDbAccessor dba(db); @@ -679,10 +747,10 @@ TEST(TestLogicalPlanner, MatchReturnSum) { QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), RETURN(sum, AS("sum"), n_prop2, AS("group")))); auto aggr = ExpectAggregate({sum}, {n_prop2}); - CheckPlan(storage, ExpectScanAll(), aggr, ExpectProduce()); + CheckPlan(storage, ExpectScanAll(), aggr, ExpectProduce()); } -TEST(TestLogicalPlanner, CreateWithSum) { +TYPED_TEST(TestPlanner, CreateWithSum) { // Test CREATE (n) WITH SUM(n.prop) AS sum database::SingleNode db; database::GraphDbAccessor dba(db); @@ -699,15 +767,14 @@ TEST(TestLogicalPlanner, CreateWithSum) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery( - single_query_parts, planning_context); + TypeParam planner(single_query_parts, planning_context); // We expect both the accumulation and aggregation because the part before // WITH updates the database. - CheckPlan(*plan, symbol_table, ExpectCreateNode(), acc, aggr, + CheckPlan(planner.plan(), symbol_table, ExpectCreateNode(), acc, aggr, ExpectProduce()); } -TEST(TestLogicalPlanner, MatchWithCreate) { +TYPED_TEST(TestPlanner, MatchWithCreate) { // Test MATCH (n) WITH n AS a CREATE (a) -[r :r]-> (b) database::SingleNode db; database::GraphDbAccessor dba(db); @@ -717,19 +784,20 @@ TEST(TestLogicalPlanner, MatchWithCreate) { MATCH(PATTERN(NODE("n"))), WITH("n", AS("a")), CREATE( PATTERN(NODE("a"), EDGE("r", Direction::OUT, {r_type}), NODE("b"))))); - CheckPlan(storage, ExpectScanAll(), ExpectProduce(), ExpectCreateExpand()); + CheckPlan(storage, ExpectScanAll(), ExpectProduce(), + ExpectCreateExpand()); } -TEST(TestLogicalPlanner, MatchReturnSkipLimit) { +TYPED_TEST(TestPlanner, MatchReturnSkipLimit) { // Test MATCH (n) RETURN n SKIP 2 LIMIT 1 AstTreeStorage storage; QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), RETURN("n", SKIP(LITERAL(2)), LIMIT(LITERAL(1))))); - CheckPlan(storage, ExpectScanAll(), ExpectProduce(), ExpectSkip(), - ExpectLimit()); + CheckPlan(storage, ExpectScanAll(), ExpectProduce(), ExpectSkip(), + ExpectLimit()); } -TEST(TestLogicalPlanner, CreateWithSkipReturnLimit) { +TYPED_TEST(TestPlanner, CreateWithSkipReturnLimit) { // Test CREATE (n) WITH n AS m SKIP 2 RETURN m LIMIT 1 AstTreeStorage storage; auto ident_n = IDENT("n"); @@ -744,18 +812,17 @@ TEST(TestLogicalPlanner, CreateWithSkipReturnLimit) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery( - single_query_parts, planning_context); + TypeParam planner(single_query_parts, planning_context); // Since we have a write query, we need to have Accumulate. This is a bit // different than Neo4j 3.0, which optimizes WITH followed by RETURN as a // single RETURN clause and then moves Skip and Limit before Accumulate. This // causes different behaviour. A newer version of Neo4j does the same thing as // us here (but who knows if they change it again). - CheckPlan(*plan, symbol_table, ExpectCreateNode(), acc, ExpectProduce(), - ExpectSkip(), ExpectProduce(), ExpectLimit()); + CheckPlan(planner.plan(), symbol_table, ExpectCreateNode(), acc, + ExpectProduce(), ExpectSkip(), ExpectProduce(), ExpectLimit()); } -TEST(TestLogicalPlanner, CreateReturnSumSkipLimit) { +TYPED_TEST(TestPlanner, CreateReturnSumSkipLimit) { // Test CREATE (n) RETURN SUM(n.prop) AS s SKIP 2 LIMIT 1 database::SingleNode db; database::GraphDbAccessor dba(db); @@ -773,13 +840,12 @@ TEST(TestLogicalPlanner, CreateReturnSumSkipLimit) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery( - single_query_parts, planning_context); - CheckPlan(*plan, symbol_table, ExpectCreateNode(), acc, aggr, ExpectProduce(), - ExpectSkip(), ExpectLimit()); + TypeParam planner(single_query_parts, planning_context); + CheckPlan(planner.plan(), symbol_table, ExpectCreateNode(), acc, aggr, + ExpectProduce(), ExpectSkip(), ExpectLimit()); } -TEST(TestLogicalPlanner, MatchReturnOrderBy) { +TYPED_TEST(TestPlanner, MatchReturnOrderBy) { // Test MATCH (n) RETURN n ORDER BY n.prop database::SingleNode db; database::GraphDbAccessor dba(db); @@ -787,10 +853,11 @@ TEST(TestLogicalPlanner, MatchReturnOrderBy) { AstTreeStorage storage; auto ret = RETURN("n", ORDER_BY(PROPERTY_LOOKUP("n", prop))); QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), ret)); - CheckPlan(storage, ExpectScanAll(), ExpectProduce(), ExpectOrderBy()); + CheckPlan(storage, ExpectScanAll(), ExpectProduce(), + ExpectOrderBy()); } -TEST(TestLogicalPlanner, CreateWithOrderByWhere) { +TYPED_TEST(TestPlanner, CreateWithOrderByWhere) { // Test CREATE (n) -[r :r]-> (m) // WITH n AS new ORDER BY new.prop, r.prop WHERE m.prop < 42 database::SingleNode db; @@ -818,13 +885,13 @@ TEST(TestLogicalPlanner, CreateWithOrderByWhere) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery( - single_query_parts, planning_context); - CheckPlan(*plan, symbol_table, ExpectCreateNode(), ExpectCreateExpand(), acc, - ExpectProduce(), ExpectOrderBy(), ExpectFilter()); + TypeParam planner(single_query_parts, planning_context); + CheckPlan(planner.plan(), symbol_table, ExpectCreateNode(), + ExpectCreateExpand(), acc, ExpectProduce(), ExpectOrderBy(), + ExpectFilter()); } -TEST(TestLogicalPlanner, ReturnAddSumCountOrderBy) { +TYPED_TEST(TestPlanner, ReturnAddSumCountOrderBy) { // Test RETURN SUM(1) + COUNT(2) AS result ORDER BY result AstTreeStorage storage; auto sum = SUM(LITERAL(1)); @@ -832,10 +899,10 @@ TEST(TestLogicalPlanner, ReturnAddSumCountOrderBy) { QUERY(SINGLE_QUERY( RETURN(ADD(sum, count), AS("result"), ORDER_BY(IDENT("result"))))); auto aggr = ExpectAggregate({sum, count}, {}); - CheckPlan(storage, aggr, ExpectProduce(), ExpectOrderBy()); + CheckPlan(storage, aggr, ExpectProduce(), ExpectOrderBy()); } -TEST(TestLogicalPlanner, MatchMerge) { +TYPED_TEST(TestPlanner, MatchMerge) { // Test MATCH (n) MERGE (n) -[r :r]- (m) // ON MATCH SET n.prop = 42 ON CREATE SET m = n // RETURN n AS n @@ -862,9 +929,8 @@ TEST(TestLogicalPlanner, MatchMerge) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery( - single_query_parts, planning_context); - CheckPlan(*plan, symbol_table, ExpectScanAll(), + TypeParam planner(single_query_parts, planning_context); + CheckPlan(planner.plan(), symbol_table, ExpectScanAll(), ExpectMerge(on_match, on_create), acc, ExpectProduce()); for (auto &op : on_match) delete op; on_match.clear(); @@ -872,7 +938,7 @@ TEST(TestLogicalPlanner, MatchMerge) { on_create.clear(); } -TEST(TestLogicalPlanner, MatchOptionalMatchWhereReturn) { +TYPED_TEST(TestPlanner, MatchOptionalMatchWhereReturn) { // Test MATCH (n) OPTIONAL MATCH (n) -[r]- (m) WHERE m.prop < 42 RETURN r database::SingleNode db; database::GraphDbAccessor dba(db); @@ -884,29 +950,30 @@ TEST(TestLogicalPlanner, MatchOptionalMatchWhereReturn) { RETURN("r"))); std::list optional{new ExpectScanAll(), new ExpectExpand(), new ExpectFilter()}; - CheckPlan(storage, ExpectScanAll(), ExpectOptional(optional), - ExpectProduce()); + CheckPlan(storage, ExpectScanAll(), ExpectOptional(optional), + ExpectProduce()); } -TEST(TestLogicalPlanner, MatchUnwindReturn) { +TYPED_TEST(TestPlanner, MatchUnwindReturn) { // Test MATCH (n) UNWIND [1,2,3] AS x RETURN n, x AstTreeStorage storage; QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), UNWIND(LIST(LITERAL(1), LITERAL(2), LITERAL(3)), AS("x")), RETURN("n", "x"))); - CheckPlan(storage, ExpectScanAll(), ExpectUnwind(), ExpectProduce()); + CheckPlan(storage, ExpectScanAll(), ExpectUnwind(), + ExpectProduce()); } -TEST(TestLogicalPlanner, ReturnDistinctOrderBySkipLimit) { +TYPED_TEST(TestPlanner, ReturnDistinctOrderBySkipLimit) { // Test RETURN DISTINCT 1 ORDER BY 1 SKIP 1 LIMIT 1 AstTreeStorage storage; QUERY(SINGLE_QUERY(RETURN_DISTINCT(LITERAL(1), AS("1"), ORDER_BY(LITERAL(1)), SKIP(LITERAL(1)), LIMIT(LITERAL(1))))); - CheckPlan(storage, ExpectProduce(), ExpectDistinct(), ExpectOrderBy(), - ExpectSkip(), ExpectLimit()); + CheckPlan(storage, ExpectProduce(), ExpectDistinct(), + ExpectOrderBy(), ExpectSkip(), ExpectLimit()); } -TEST(TestLogicalPlanner, CreateWithDistinctSumWhereReturn) { +TYPED_TEST(TestPlanner, CreateWithDistinctSumWhereReturn) { // Test CREATE (n) WITH DISTINCT SUM(n.prop) AS s WHERE s < 42 RETURN s database::SingleNode db; database::GraphDbAccessor dba(db); @@ -924,13 +991,12 @@ TEST(TestLogicalPlanner, CreateWithDistinctSumWhereReturn) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery( - single_query_parts, planning_context); - CheckPlan(*plan, symbol_table, ExpectCreateNode(), acc, aggr, ExpectProduce(), - ExpectDistinct(), ExpectFilter(), ExpectProduce()); + TypeParam planner(single_query_parts, planning_context); + CheckPlan(planner.plan(), symbol_table, ExpectCreateNode(), acc, aggr, + ExpectProduce(), ExpectDistinct(), ExpectFilter(), ExpectProduce()); } -TEST(TestLogicalPlanner, MatchCrossReferenceVariable) { +TYPED_TEST(TestPlanner, MatchCrossReferenceVariable) { // Test MATCH (n {prop: m.prop}), (m {prop: n.prop}) RETURN n database::SingleNode db; database::GraphDbAccessor dba(db); @@ -945,11 +1011,11 @@ TEST(TestLogicalPlanner, MatchCrossReferenceVariable) { QUERY(SINGLE_QUERY(MATCH(PATTERN(node_n), PATTERN(node_m)), RETURN("n"))); // We expect both ScanAll to come before filters (2 are joined into one), // because they need to populate the symbol values. - CheckPlan(storage, ExpectScanAll(), ExpectScanAll(), ExpectFilter(), - ExpectProduce()); + CheckPlan(storage, ExpectScanAll(), ExpectScanAll(), + ExpectFilter(), ExpectProduce()); } -TEST(TestLogicalPlanner, MatchWhereBeforeExpand) { +TYPED_TEST(TestPlanner, MatchWhereBeforeExpand) { // Test MATCH (n) -[r]- (m) WHERE n.prop < 42 RETURN n database::SingleNode db; database::GraphDbAccessor dba(db); @@ -959,11 +1025,11 @@ TEST(TestLogicalPlanner, MatchWhereBeforeExpand) { WHERE(LESS(PROPERTY_LOOKUP("n", prop), LITERAL(42))), RETURN("n"))); // We expect Fitler to come immediately after ScanAll, since it only uses `n`. - CheckPlan(storage, ExpectScanAll(), ExpectFilter(), ExpectExpand(), - ExpectProduce()); + CheckPlan(storage, ExpectScanAll(), ExpectFilter(), ExpectExpand(), + ExpectProduce()); } -TEST(TestLogicalPlanner, MultiMatchWhere) { +TYPED_TEST(TestPlanner, MultiMatchWhere) { // Test MATCH (n) -[r]- (m) MATCH (l) WHERE n.prop < 42 RETURN n database::SingleNode db; database::GraphDbAccessor dba(db); @@ -975,11 +1041,11 @@ TEST(TestLogicalPlanner, MultiMatchWhere) { RETURN("n"))); // Even though WHERE is in the second MATCH clause, we expect Filter to come // before second ScanAll, since it only uses the value from first ScanAll. - CheckPlan(storage, ExpectScanAll(), ExpectFilter(), ExpectExpand(), - ExpectScanAll(), ExpectProduce()); + CheckPlan(storage, ExpectScanAll(), ExpectFilter(), ExpectExpand(), + ExpectScanAll(), ExpectProduce()); } -TEST(TestLogicalPlanner, MatchOptionalMatchWhere) { +TYPED_TEST(TestPlanner, MatchOptionalMatchWhere) { // Test MATCH (n) -[r]- (m) OPTIONAL MATCH (l) WHERE n.prop < 42 RETURN n database::SingleNode db; database::GraphDbAccessor dba(db); @@ -993,11 +1059,11 @@ TEST(TestLogicalPlanner, MatchOptionalMatchWhere) { // first ScanAll, it must remain part of the Optional. It should come before // optional ScanAll. std::list optional{new ExpectFilter(), new ExpectScanAll()}; - CheckPlan(storage, ExpectScanAll(), ExpectExpand(), ExpectOptional(optional), - ExpectProduce()); + CheckPlan(storage, ExpectScanAll(), ExpectExpand(), + ExpectOptional(optional), ExpectProduce()); } -TEST(TestLogicalPlanner, MatchReturnAsterisk) { +TYPED_TEST(TestPlanner, MatchReturnAsterisk) { // Test MATCH (n) -[e]- (m) RETURN *, m.prop database::SingleNode db; database::GraphDbAccessor dba(db); @@ -1012,19 +1078,18 @@ TEST(TestLogicalPlanner, MatchReturnAsterisk) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery( - single_query_parts, planning_context); - CheckPlan(*plan, symbol_table, ExpectScanAll(), ExpectExpand(), + TypeParam planner(single_query_parts, planning_context); + CheckPlan(planner.plan(), symbol_table, ExpectScanAll(), ExpectExpand(), ExpectProduce()); std::vector output_names; - for (const auto &output_symbol : plan->OutputSymbols(symbol_table)) { + for (const auto &output_symbol : planner.plan().OutputSymbols(symbol_table)) { output_names.emplace_back(output_symbol.name()); } std::vector expected_names{"e", "m", "n", "m.prop"}; EXPECT_EQ(output_names, expected_names); } -TEST(TestLogicalPlanner, MatchReturnAsteriskSum) { +TYPED_TEST(TestPlanner, MatchReturnAsteriskSum) { // Test MATCH (n) RETURN *, SUM(n.prop) AS s database::SingleNode db; database::GraphDbAccessor dba(db); @@ -1039,9 +1104,8 @@ TEST(TestLogicalPlanner, MatchReturnAsteriskSum) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery( - single_query_parts, planning_context); - auto *produce = dynamic_cast(plan.get()); + TypeParam planner(single_query_parts, planning_context); + auto *produce = dynamic_cast(&planner.plan()); ASSERT_TRUE(produce); const auto &named_expressions = produce->named_expressions(); ASSERT_EQ(named_expressions.size(), 2); @@ -1049,16 +1113,17 @@ TEST(TestLogicalPlanner, MatchReturnAsteriskSum) { dynamic_cast(named_expressions[0]->expression_); ASSERT_TRUE(expanded_ident); auto aggr = ExpectAggregate({sum}, {expanded_ident}); - CheckPlan(*plan, symbol_table, ExpectScanAll(), aggr, ExpectProduce()); + CheckPlan(planner.plan(), symbol_table, ExpectScanAll(), aggr, + ExpectProduce()); std::vector output_names; - for (const auto &output_symbol : plan->OutputSymbols(symbol_table)) { + for (const auto &output_symbol : planner.plan().OutputSymbols(symbol_table)) { output_names.emplace_back(output_symbol.name()); } std::vector expected_names{"n", "s"}; EXPECT_EQ(output_names, expected_names); } -TEST(TestLogicalPlanner, UnwindMergeNodeProperty) { +TYPED_TEST(TestPlanner, UnwindMergeNodeProperty) { // Test UNWIND [1] AS i MERGE (n {prop: i}) database::SingleNode db; database::GraphDbAccessor dba(db); @@ -1069,22 +1134,23 @@ TEST(TestLogicalPlanner, UnwindMergeNodeProperty) { SINGLE_QUERY(UNWIND(LIST(LITERAL(1)), AS("i")), MERGE(PATTERN(node_n)))); std::list on_match{new ExpectScanAll(), new ExpectFilter()}; std::list on_create{new ExpectCreateNode()}; - CheckPlan(storage, ExpectUnwind(), ExpectMerge(on_match, on_create)); + CheckPlan(storage, ExpectUnwind(), + ExpectMerge(on_match, on_create)); for (auto &op : on_match) delete op; for (auto &op : on_create) delete op; } -TEST(TestLogicalPlanner, MultipleOptionalMatchReturn) { +TYPED_TEST(TestPlanner, MultipleOptionalMatchReturn) { // Test OPTIONAL MATCH (n) OPTIONAL MATCH (m) RETURN n AstTreeStorage storage; QUERY(SINGLE_QUERY(OPTIONAL_MATCH(PATTERN(NODE("n"))), OPTIONAL_MATCH(PATTERN(NODE("m"))), RETURN("n"))); std::list optional{new ExpectScanAll()}; - CheckPlan(storage, ExpectOptional(optional), ExpectOptional(optional), - ExpectProduce()); + CheckPlan(storage, ExpectOptional(optional), + ExpectOptional(optional), ExpectProduce()); } -TEST(TestLogicalPlanner, FunctionAggregationReturn) { +TYPED_TEST(TestPlanner, FunctionAggregationReturn) { // Test RETURN sqrt(SUM(2)) AS result, 42 AS group_by AstTreeStorage storage; auto sum = SUM(LITERAL(2)); @@ -1092,17 +1158,17 @@ TEST(TestLogicalPlanner, FunctionAggregationReturn) { QUERY(SINGLE_QUERY( RETURN(FN("sqrt", sum), AS("result"), group_by_literal, AS("group_by")))); auto aggr = ExpectAggregate({sum}, {group_by_literal}); - CheckPlan(storage, aggr, ExpectProduce()); + CheckPlan(storage, aggr, ExpectProduce()); } -TEST(TestLogicalPlanner, FunctionWithoutArguments) { +TYPED_TEST(TestPlanner, FunctionWithoutArguments) { // Test RETURN pi() AS pi AstTreeStorage storage; QUERY(SINGLE_QUERY(RETURN(FN("pi"), AS("pi")))); - CheckPlan(storage, ExpectProduce()); + CheckPlan(storage, ExpectProduce()); } -TEST(TestLogicalPlanner, ListLiteralAggregationReturn) { +TYPED_TEST(TestPlanner, ListLiteralAggregationReturn) { // Test RETURN [SUM(2)] AS result, 42 AS group_by AstTreeStorage storage; auto sum = SUM(LITERAL(2)); @@ -1110,10 +1176,10 @@ TEST(TestLogicalPlanner, ListLiteralAggregationReturn) { QUERY(SINGLE_QUERY( RETURN(LIST(sum), AS("result"), group_by_literal, AS("group_by")))); auto aggr = ExpectAggregate({sum}, {group_by_literal}); - CheckPlan(storage, aggr, ExpectProduce()); + CheckPlan(storage, aggr, ExpectProduce()); } -TEST(TestLogicalPlanner, MapLiteralAggregationReturn) { +TYPED_TEST(TestPlanner, MapLiteralAggregationReturn) { // Test RETURN {sum: SUM(2)} AS result, 42 AS group_by AstTreeStorage storage; database::SingleNode db; @@ -1123,10 +1189,10 @@ TEST(TestLogicalPlanner, MapLiteralAggregationReturn) { QUERY(SINGLE_QUERY(RETURN(MAP({PROPERTY_PAIR("sum"), sum}), AS("result"), group_by_literal, AS("group_by")))); auto aggr = ExpectAggregate({sum}, {group_by_literal}); - CheckPlan(storage, aggr, ExpectProduce()); + CheckPlan(storage, aggr, ExpectProduce()); } -TEST(TestLogicalPlanner, EmptyListIndexAggregation) { +TYPED_TEST(TestPlanner, EmptyListIndexAggregation) { // Test RETURN [][SUM(2)] AS result, 42 AS group_by AstTreeStorage storage; auto sum = SUM(LITERAL(2)); @@ -1139,10 +1205,10 @@ TEST(TestLogicalPlanner, EmptyListIndexAggregation) { // sub-expression of a binary operator which contains an aggregation. This is // similar to grouping by '1' in `RETURN 1 + SUM(2)`. auto aggr = ExpectAggregate({sum}, {empty_list, group_by_literal}); - CheckPlan(storage, aggr, ExpectProduce()); + CheckPlan(storage, aggr, ExpectProduce()); } -TEST(TestLogicalPlanner, ListSliceAggregationReturn) { +TYPED_TEST(TestPlanner, ListSliceAggregationReturn) { // Test RETURN [1, 2][0..SUM(2)] AS result, 42 AS group_by AstTreeStorage storage; auto sum = SUM(LITERAL(2)); @@ -1153,20 +1219,20 @@ TEST(TestLogicalPlanner, ListSliceAggregationReturn) { // Similarly to EmptyListIndexAggregation test, we expect grouping by list and // '42', because slicing is an operator. auto aggr = ExpectAggregate({sum}, {list, group_by_literal}); - CheckPlan(storage, aggr, ExpectProduce()); + CheckPlan(storage, aggr, ExpectProduce()); } -TEST(TestLogicalPlanner, ListWithAggregationAndGroupBy) { +TYPED_TEST(TestPlanner, ListWithAggregationAndGroupBy) { // Test RETURN [sum(2), 42] AstTreeStorage storage; auto sum = SUM(LITERAL(2)); auto group_by_literal = LITERAL(42); QUERY(SINGLE_QUERY(RETURN(LIST(sum, group_by_literal), AS("result")))); auto aggr = ExpectAggregate({sum}, {group_by_literal}); - CheckPlan(storage, aggr, ExpectProduce()); + CheckPlan(storage, aggr, ExpectProduce()); } -TEST(TestLogicalPlanner, AggregatonWithListWithAggregationAndGroupBy) { +TYPED_TEST(TestPlanner, AggregatonWithListWithAggregationAndGroupBy) { // Test RETURN sum(2), [sum(3), 42] AstTreeStorage storage; auto sum2 = SUM(LITERAL(2)); @@ -1175,10 +1241,10 @@ TEST(TestLogicalPlanner, AggregatonWithListWithAggregationAndGroupBy) { QUERY(SINGLE_QUERY( RETURN(sum2, AS("sum2"), LIST(sum3, group_by_literal), AS("list")))); auto aggr = ExpectAggregate({sum2, sum3}, {group_by_literal}); - CheckPlan(storage, aggr, ExpectProduce()); + CheckPlan(storage, aggr, ExpectProduce()); } -TEST(TestLogicalPlanner, MapWithAggregationAndGroupBy) { +TYPED_TEST(TestPlanner, MapWithAggregationAndGroupBy) { // Test RETURN {lit: 42, sum: sum(2)} database::SingleNode db; AstTreeStorage storage; @@ -1188,10 +1254,10 @@ TEST(TestLogicalPlanner, MapWithAggregationAndGroupBy) { {PROPERTY_PAIR("lit"), group_by_literal}), AS("result")))); auto aggr = ExpectAggregate({sum}, {group_by_literal}); - CheckPlan(storage, aggr, ExpectProduce()); + CheckPlan(storage, aggr, ExpectProduce()); } -TEST(TestLogicalPlanner, CreateIndex) { +TYPED_TEST(TestPlanner, CreateIndex) { // Test CREATE INDEX ON :Label(property) database::SingleNode db; database::GraphDbAccessor dba(db); @@ -1199,10 +1265,10 @@ TEST(TestLogicalPlanner, CreateIndex) { auto property = dba.Property("property"); AstTreeStorage storage; QUERY(SINGLE_QUERY(CREATE_INDEX_ON(label, property))); - CheckPlan(storage, ExpectCreateIndex(label, property)); + CheckPlan(storage, ExpectCreateIndex(label, property)); } -TEST(TestLogicalPlanner, AtomIndexedLabelProperty) { +TYPED_TEST(TestPlanner, AtomIndexedLabelProperty) { // Test MATCH (n :label {property: 42, not_indexed: 0}) RETURN n AstTreeStorage storage; database::SingleNode db; @@ -1227,16 +1293,14 @@ TEST(TestLogicalPlanner, AtomIndexedLabelProperty) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery( - single_query_parts, planning_context); - - CheckPlan(*plan, symbol_table, + TypeParam planner(single_query_parts, planning_context); + CheckPlan(planner.plan(), symbol_table, ExpectScanAllByLabelPropertyValue(label, property, lit_42), ExpectFilter(), ExpectProduce()); } } -TEST(TestLogicalPlanner, AtomPropertyWhereLabelIndexing) { +TYPED_TEST(TestPlanner, AtomPropertyWhereLabelIndexing) { // Test MATCH (n {property: 42}) WHERE n.not_indexed AND n:label RETURN n AstTreeStorage storage; database::SingleNode db; @@ -1261,16 +1325,14 @@ TEST(TestLogicalPlanner, AtomPropertyWhereLabelIndexing) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery( - single_query_parts, planning_context); - - CheckPlan(*plan, symbol_table, + TypeParam planner(single_query_parts, planning_context); + CheckPlan(planner.plan(), symbol_table, ExpectScanAllByLabelPropertyValue(label, property, lit_42), ExpectFilter(), ExpectProduce()); } } -TEST(TestLogicalPlanner, WhereIndexedLabelProperty) { +TYPED_TEST(TestPlanner, WhereIndexedLabelProperty) { // Test MATCH (n :label) WHERE n.property = 42 RETURN n AstTreeStorage storage; database::SingleNode db; @@ -1289,15 +1351,14 @@ TEST(TestLogicalPlanner, WhereIndexedLabelProperty) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery( - single_query_parts, planning_context); - CheckPlan(*plan, symbol_table, + TypeParam planner(single_query_parts, planning_context); + CheckPlan(planner.plan(), symbol_table, ExpectScanAllByLabelPropertyValue(label, property, lit_42), ExpectProduce()); } } -TEST(TestLogicalPlanner, BestPropertyIndexed) { +TYPED_TEST(TestPlanner, BestPropertyIndexed) { // Test MATCH (n :label) WHERE n.property = 1 AND n.better = 42 RETURN n AstTreeStorage storage; database::SingleNode db; @@ -1329,15 +1390,14 @@ TEST(TestLogicalPlanner, BestPropertyIndexed) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery( - single_query_parts, planning_context); - CheckPlan(*plan, symbol_table, + TypeParam planner(single_query_parts, planning_context); + CheckPlan(planner.plan(), symbol_table, ExpectScanAllByLabelPropertyValue(label, better, lit_42), ExpectFilter(), ExpectProduce()); } } -TEST(TestLogicalPlanner, MultiPropertyIndexScan) { +TYPED_TEST(TestPlanner, MultiPropertyIndexScan) { // Test MATCH (n :label1), (m :label2) WHERE n.prop1 = 1 AND m.prop2 = 2 // RETURN n, m database::SingleNode db; @@ -1361,15 +1421,14 @@ TEST(TestLogicalPlanner, MultiPropertyIndexScan) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery( - single_query_parts, planning_context); - CheckPlan(*plan, symbol_table, + TypeParam planner(single_query_parts, planning_context); + CheckPlan(planner.plan(), symbol_table, ExpectScanAllByLabelPropertyValue(label1, prop1, lit_1), ExpectScanAllByLabelPropertyValue(label2, prop2, lit_2), ExpectProduce()); } -TEST(TestLogicalPlanner, WhereIndexedLabelPropertyRange) { +TYPED_TEST(TestPlanner, WhereIndexedLabelPropertyRange) { // Test MATCH (n :label) WHERE n.property REL_OP 42 RETURN n // REL_OP is one of: `<`, `<=`, `>`, `>=` database::SingleNode db; @@ -1380,8 +1439,9 @@ TEST(TestLogicalPlanner, WhereIndexedLabelPropertyRange) { AstTreeStorage storage; auto lit_42 = LITERAL(42); auto n_prop = PROPERTY_LOOKUP("n", property); - auto check_planned_range = [&label, &property, &dba]( - const auto &rel_expr, auto lower_bound, auto upper_bound) { + auto check_planned_range = [&label, &property, &dba](const auto &rel_expr, + auto lower_bound, + auto upper_bound) { // Shadow the first storage, so that the query is created in this one. AstTreeStorage storage; QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n", label))), WHERE(rel_expr), @@ -1391,9 +1451,8 @@ TEST(TestLogicalPlanner, WhereIndexedLabelPropertyRange) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery( - single_query_parts, planning_context); - CheckPlan(*plan, symbol_table, + TypeParam planner(single_query_parts, planning_context); + CheckPlan(planner.plan(), symbol_table, ExpectScanAllByLabelPropertyRange(label, property, lower_bound, upper_bound), ExpectProduce()); @@ -1424,7 +1483,7 @@ TEST(TestLogicalPlanner, WhereIndexedLabelPropertyRange) { } } -TEST(TestLogicalPlanner, UnableToUsePropertyIndex) { +TYPED_TEST(TestPlanner, UnableToUsePropertyIndex) { // Test MATCH (n: label) WHERE n.property = n.property RETURN n database::SingleNode db; database::GraphDbAccessor dba(db); @@ -1443,16 +1502,15 @@ TEST(TestLogicalPlanner, UnableToUsePropertyIndex) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery( - single_query_parts, planning_context); + TypeParam planner(single_query_parts, planning_context); // We can only get ScanAllByLabelIndex, because we are comparing properties // with those on the same node. - CheckPlan(*plan, symbol_table, ExpectScanAllByLabel(), ExpectFilter(), - ExpectProduce()); + CheckPlan(planner.plan(), symbol_table, ExpectScanAllByLabel(), + ExpectFilter(), ExpectProduce()); } } -TEST(TestLogicalPlanner, SecondPropertyIndex) { +TYPED_TEST(TestPlanner, SecondPropertyIndex) { // Test MATCH (n :label), (m :label) WHERE m.property = n.property RETURN n database::SingleNode db; database::GraphDbAccessor dba(db); @@ -1472,44 +1530,45 @@ TEST(TestLogicalPlanner, SecondPropertyIndex) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery( - single_query_parts, planning_context); + TypeParam planner(single_query_parts, planning_context); CheckPlan( - *plan, symbol_table, ExpectScanAllByLabel(), + planner.plan(), symbol_table, ExpectScanAllByLabel(), // Note: We are scanning for m, therefore property should equal n_prop. ExpectScanAllByLabelPropertyValue(label, property, n_prop), ExpectProduce()); } } -TEST(TestLogicalPlanner, ReturnSumGroupByAll) { +TYPED_TEST(TestPlanner, ReturnSumGroupByAll) { // Test RETURN sum([1,2,3]), all(x in [1] where x = 1) AstTreeStorage storage; auto sum = SUM(LIST(LITERAL(1), LITERAL(2), LITERAL(3))); auto *all = ALL("x", LIST(LITERAL(1)), WHERE(EQ(IDENT("x"), LITERAL(1)))); QUERY(SINGLE_QUERY(RETURN(sum, AS("sum"), all, AS("all")))); auto aggr = ExpectAggregate({sum}, {all}); - CheckPlan(storage, aggr, ExpectProduce()); + CheckPlan(storage, aggr, ExpectProduce()); } -TEST(TestLogicalPlanner, MatchExpandVariable) { +TYPED_TEST(TestPlanner, MatchExpandVariable) { // Test MATCH (n) -[r *..3]-> (m) RETURN r AstTreeStorage storage; auto edge = EDGE_VARIABLE("r"); edge->upper_bound_ = LITERAL(3); QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"), edge, NODE("m"))), RETURN("r"))); - CheckPlan(storage, ExpectScanAll(), ExpectExpandVariable(), ExpectProduce()); + CheckPlan(storage, ExpectScanAll(), ExpectExpandVariable(), + ExpectProduce()); } -TEST(TestLogicalPlanner, MatchExpandVariableNoBounds) { +TYPED_TEST(TestPlanner, MatchExpandVariableNoBounds) { // Test MATCH (n) -[r *]-> (m) RETURN r AstTreeStorage storage; auto edge = EDGE_VARIABLE("r"); QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"), edge, NODE("m"))), RETURN("r"))); - CheckPlan(storage, ExpectScanAll(), ExpectExpandVariable(), ExpectProduce()); + CheckPlan(storage, ExpectScanAll(), ExpectExpandVariable(), + ExpectProduce()); } -TEST(TestLogicalPlanner, MatchExpandVariableInlinedFilter) { +TYPED_TEST(TestPlanner, MatchExpandVariableInlinedFilter) { // Test MATCH (n) -[r :type * {prop: 42}]-> (m) RETURN r database::SingleNode db; database::GraphDbAccessor dba(db); @@ -1519,12 +1578,13 @@ TEST(TestLogicalPlanner, MatchExpandVariableInlinedFilter) { auto edge = EDGE_VARIABLE("r", Direction::BOTH, {type}); edge->properties_[prop] = LITERAL(42); QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"), edge, NODE("m"))), RETURN("r"))); - CheckPlan(storage, ExpectScanAll(), - ExpectExpandVariable(), // Filter is both inlined and post-expand - ExpectFilter(), ExpectProduce()); + CheckPlan( + storage, ExpectScanAll(), + ExpectExpandVariable(), // Filter is both inlined and post-expand + ExpectFilter(), ExpectProduce()); } -TEST(TestLogicalPlanner, MatchExpandVariableNotInlinedFilter) { +TYPED_TEST(TestPlanner, MatchExpandVariableNotInlinedFilter) { // Test MATCH (n) -[r :type * {prop: m.prop}]-> (m) RETURN r database::SingleNode db; database::GraphDbAccessor dba(db); @@ -1534,11 +1594,11 @@ TEST(TestLogicalPlanner, MatchExpandVariableNotInlinedFilter) { auto edge = EDGE_VARIABLE("r", Direction::BOTH, {type}); edge->properties_[prop] = EQ(PROPERTY_LOOKUP("m", prop), LITERAL(42)); QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"), edge, NODE("m"))), RETURN("r"))); - CheckPlan(storage, ExpectScanAll(), ExpectExpandVariable(), ExpectFilter(), - ExpectProduce()); + CheckPlan(storage, ExpectScanAll(), ExpectExpandVariable(), + ExpectFilter(), ExpectProduce()); } -TEST(TestLogicalPlanner, UnwindMatchVariable) { +TYPED_TEST(TestPlanner, UnwindMatchVariable) { // Test UNWIND [1,2,3] AS depth MATCH (n) -[r*d]-> (m) RETURN r AstTreeStorage storage; auto edge = EDGE_VARIABLE("r", Direction::OUT); @@ -1546,11 +1606,11 @@ TEST(TestLogicalPlanner, UnwindMatchVariable) { edge->upper_bound_ = IDENT("d"); QUERY(SINGLE_QUERY(UNWIND(LIST(LITERAL(1), LITERAL(2), LITERAL(3)), AS("d")), MATCH(PATTERN(NODE("n"), edge, NODE("m"))), RETURN("r"))); - CheckPlan(storage, ExpectUnwind(), ExpectScanAll(), ExpectExpandVariable(), - ExpectProduce()); + CheckPlan(storage, ExpectUnwind(), ExpectScanAll(), + ExpectExpandVariable(), ExpectProduce()); } -TEST(TestLogicalPlanner, MatchBreadthFirst) { +TYPED_TEST(TestPlanner, MatchBreadthFirst) { // Test MATCH (n) -[r:type *..10 (r, n|n)]-> (m) RETURN r database::SingleNode db; database::GraphDbAccessor dba(db); @@ -1564,11 +1624,11 @@ TEST(TestLogicalPlanner, MatchBreadthFirst) { bfs->filter_expression_ = IDENT("n"); bfs->upper_bound_ = LITERAL(10); QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"), bfs, NODE("m"))), RETURN("r"))); - CheckPlan(storage, ExpectScanAll(), ExpectExpandBreadthFirst(), - ExpectProduce()); + CheckPlan(storage, ExpectScanAll(), ExpectExpandBreadthFirst(), + ExpectProduce()); } -TEST(TestLogicalPlanner, MatchDoubleScanToExpandExisting) { +TYPED_TEST(TestPlanner, MatchDoubleScanToExpandExisting) { // Test MATCH (n) -[r]- (m :label) RETURN r database::SingleNode db; database::GraphDbAccessor dba(db); @@ -1581,15 +1641,14 @@ TEST(TestLogicalPlanner, MatchDoubleScanToExpandExisting) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery( - single_query_parts, planning_context); + TypeParam planner(single_query_parts, planning_context); // We expect 2x ScanAll and then Expand, since we are guessing that is // faster (due to low label index vertex count). - CheckPlan(*plan, symbol_table, ExpectScanAll(), ExpectScanAllByLabel(), - ExpectExpand(), ExpectProduce()); + CheckPlan(planner.plan(), symbol_table, ExpectScanAll(), + ExpectScanAllByLabel(), ExpectExpand(), ExpectProduce()); } -TEST(TestLogicalPlanner, MatchScanToExpand) { +TYPED_TEST(TestPlanner, MatchScanToExpand) { // Test MATCH (n) -[r]- (m :label {property: 1}) RETURN r database::SingleNode db; auto label = database::GraphDbAccessor(db).Label("label"); @@ -1619,16 +1678,15 @@ TEST(TestLogicalPlanner, MatchScanToExpand) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery( - single_query_parts, planning_context); + TypeParam planner(single_query_parts, planning_context); // We expect 1x ScanAllByLabel and then Expand, since we are guessing that // is faster (due to high label index vertex count). - CheckPlan(*plan, symbol_table, ExpectScanAll(), ExpectExpand(), + CheckPlan(planner.plan(), symbol_table, ExpectScanAll(), ExpectExpand(), ExpectFilter(), ExpectProduce()); } } -TEST(TestLogicalPlanner, MatchWhereAndSplit) { +TYPED_TEST(TestPlanner, MatchWhereAndSplit) { // Test MATCH (n) -[r]- (m) WHERE n.prop AND r.prop RETURN m database::SingleNode db; database::GraphDbAccessor dba(db); @@ -1639,11 +1697,11 @@ TEST(TestLogicalPlanner, MatchWhereAndSplit) { WHERE(AND(PROPERTY_LOOKUP("n", prop), PROPERTY_LOOKUP("r", prop))), RETURN("m"))); // We expect `n.prop` filter right after scanning `n`. - CheckPlan(storage, ExpectScanAll(), ExpectFilter(), ExpectExpand(), - ExpectFilter(), ExpectProduce()); + CheckPlan(storage, ExpectScanAll(), ExpectFilter(), ExpectExpand(), + ExpectFilter(), ExpectProduce()); } -TEST(TestLogicalPlanner, ReturnAsteriskOmitsLambdaSymbols) { +TYPED_TEST(TestPlanner, ReturnAsteriskOmitsLambdaSymbols) { // Test MATCH (n) -[r* (ie, in | true)]- (m) RETURN * database::SingleNode db; database::GraphDbAccessor dba(db); @@ -1660,9 +1718,8 @@ TEST(TestLogicalPlanner, ReturnAsteriskOmitsLambdaSymbols) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery( - single_query_parts, planning_context); - auto *produce = dynamic_cast(plan.get()); + TypeParam planner(single_query_parts, planning_context); + auto *produce = dynamic_cast(&planner.plan()); ASSERT_TRUE(produce); std::vector outputs; for (const auto &output_symbol : produce->OutputSymbols(symbol_table)) {