diff --git a/src/query/frontend/ast/ast.cpp b/src/query/frontend/ast/ast.cpp index 16c48bf80..a9ee568ef 100644 --- a/src/query/frontend/ast/ast.cpp +++ b/src/query/frontend/ast/ast.cpp @@ -7,6 +7,18 @@ namespace query { +// Id for boost's archive get_helper needs to be unique among all ids. If it +// isn't unique, then different instances (even across different types!) will +// replace the previous helper. The documentation recommends to take an address +// of a function, which should be unique in the produced binary. +// It might seem logical to take an address of a AstTreeStorage constructor, but +// according to c++ standard, the constructor function need not have an address. +// Additionally, pointers to member functions are not required to contain the +// address of the function +// (https://isocpp.org/wiki/faq/pointers-to-members#addr-of-memfn). So, to be +// safe, use a regular top-level function. +void *const AstTreeStorage::kHelperId = (void *)CloneReturnBody; + AstTreeStorage::AstTreeStorage() { storage_.emplace_back(new Query(next_uid_++)); } @@ -15,14 +27,6 @@ Query *AstTreeStorage::query() const { return dynamic_cast<Query *>(storage_[0].get()); } -int AstTreeStorage::MaximumStorageUid() const { - int max_uid = -1; - for (const auto &tree : storage_) { - max_uid = std::max(max_uid, tree->uid()); - } - return max_uid; -} - ReturnBody CloneReturnBody(AstTreeStorage &storage, const ReturnBody &body) { ReturnBody new_body; new_body.distinct = body.distinct; diff --git a/src/query/frontend/ast/ast.hpp b/src/query/frontend/ast/ast.hpp index 00f5a998f..589a6d0a8 100644 --- a/src/query/frontend/ast/ast.hpp +++ b/src/query/frontend/ast/ast.hpp @@ -1,6 +1,5 @@ #pragma once -#include <map> #include <memory> #include <unordered_map> #include <vector> @@ -9,13 +8,10 @@ #include "boost/serialization/split_member.hpp" #include "boost/serialization/string.hpp" #include "boost/serialization/vector.hpp" -#include "glog/logging.h" -#include "database/graph_db.hpp" #include "query/frontend/ast/ast_visitor.hpp" #include "query/frontend/semantic/symbol.hpp" #include "query/interpret/awesome_memgraph_functions.hpp" -#include "query/parameters.hpp" #include "query/typed_value.hpp" #include "storage/types.hpp" #include "utils/serialization.hpp" @@ -35,6 +31,10 @@ struct hash<std::pair<std::string, storage::Property>> { }; } // namespace std +namespace database { +class GraphDbAccessor; +} + namespace query { #define CLONE_BINARY_EXPRESSION \ @@ -60,6 +60,7 @@ namespace query { class Tree; + // It would be better to call this AstTree, but we already have a class Tree, // which could be renamed to Node or AstTreeNode, but we also have a class // called NodeAtom... @@ -82,14 +83,16 @@ class AstTreeStorage { Query *query() const; + /// Id for using get_helper<AstTreeStorage> in boost archives. + static void * const kHelperId; + /// Load an Ast Node into this storage. template <class TArchive, class TNode> void Load(TArchive &ar, TNode &node) { - auto &tmp_ast = ar.template get_helper<AstTreeStorage>(); - tmp_ast.storage_ = std::move(storage_); + auto &tmp_ast = ar.template get_helper<AstTreeStorage>(kHelperId); + std::swap(*this, tmp_ast); ar >> node; - storage_ = std::move(tmp_ast.storage_); - next_uid_ = MaximumStorageUid() + 1; + std::swap(*this, tmp_ast); } /// Load a Query into this storage. @@ -102,8 +105,6 @@ class AstTreeStorage { int next_uid_ = 0; std::vector<std::unique_ptr<Tree>> storage_; - int MaximumStorageUid() const; - template <class TArchive, class TNode> friend void LoadPointer(TArchive &ar, TNode *&node); }; @@ -117,7 +118,8 @@ template <class TArchive, class TNode> void LoadPointer(TArchive &ar, TNode *&node) { ar >> node; if (node) { - auto &ast_storage = ar.template get_helper<AstTreeStorage>(); + auto &ast_storage = + ar.template get_helper<AstTreeStorage>(AstTreeStorage::kHelperId); auto found = std::find_if(ast_storage.storage_.begin(), ast_storage.storage_.end(), [&](const auto &n) { return n->uid() == node->uid(); }); @@ -127,6 +129,7 @@ void LoadPointer(TArchive &ar, TNode *&node) { dynamic_cast<TNode *>(found->get()) == node); if (ast_storage.storage_.end() == found) { ast_storage.storage_.emplace_back(node); + ast_storage.next_uid_ = std::max(ast_storage.next_uid_, node->uid() + 1); } } } diff --git a/src/query/frontend/semantic/symbol.hpp b/src/query/frontend/semantic/symbol.hpp index 3e6e85e12..a6becfd19 100644 --- a/src/query/frontend/semantic/symbol.hpp +++ b/src/query/frontend/semantic/symbol.hpp @@ -2,6 +2,9 @@ #include <string> +#include "boost/serialization/serialization.hpp" +#include "boost/serialization/string.hpp" + namespace query { class Symbol { diff --git a/src/query/interpret/eval.hpp b/src/query/interpret/eval.hpp index a01782617..a4d6f9b83 100644 --- a/src/query/interpret/eval.hpp +++ b/src/query/interpret/eval.hpp @@ -11,6 +11,7 @@ #include "query/frontend/ast/ast.hpp" #include "query/frontend/semantic/symbol_table.hpp" #include "query/interpret/frame.hpp" +#include "query/parameters.hpp" #include "query/typed_value.hpp" #include "utils/exceptions.hpp" diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index a091e0181..47845bf5c 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -1,16 +1,23 @@ +#include "query/plan/operator.hpp" + #include <algorithm> #include <limits> +#include <string> #include <type_traits> #include <utility> +#include "boost/archive/binary_iarchive.hpp" +#include "boost/archive/binary_oarchive.hpp" +#include "boost/serialization/export.hpp" #include "glog/logging.h" +#include "database/graph_db_accessor.hpp" #include "query/context.hpp" #include "query/exceptions.hpp" #include "query/frontend/ast/ast.hpp" +#include "query/frontend/semantic/symbol_table.hpp" #include "query/interpret/eval.hpp" - -#include "query/plan/operator.hpp" +#include "query/path.hpp" // macro for the default implementation of LogicalOperator::Accept // that accepts the visitor and visits it's input_ operator @@ -79,7 +86,7 @@ std::unique_ptr<Cursor> Once::MakeCursor(database::GraphDbAccessor &) const { void Once::OnceCursor::Reset() { did_pull_ = false; } -CreateNode::CreateNode(const NodeAtom *node_atom, +CreateNode::CreateNode(NodeAtom *node_atom, const std::shared_ptr<LogicalOperator> &input) : node_atom_(node_atom), input_(input ? input : std::make_shared<Once>()) {} @@ -117,7 +124,7 @@ void CreateNode::CreateNodeCursor::Create(Frame &frame, Context &context) { frame[context.symbol_table_.at(*self_.node_atom_->identifier_)] = new_node; } -CreateExpand::CreateExpand(const NodeAtom *node_atom, const EdgeAtom *edge_atom, +CreateExpand::CreateExpand(NodeAtom *node_atom, EdgeAtom *edge_atom, const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, bool existing_node) : node_atom_(node_atom), @@ -2546,3 +2553,35 @@ void Union::UnionCursor::Reset() { } } // namespace query::plan + +BOOST_CLASS_EXPORT(query::plan::Once); +BOOST_CLASS_EXPORT(query::plan::CreateNode); +BOOST_CLASS_EXPORT(query::plan::CreateExpand); +BOOST_CLASS_EXPORT(query::plan::ScanAll); +BOOST_CLASS_EXPORT(query::plan::ScanAllByLabel); +BOOST_CLASS_EXPORT(query::plan::ScanAllByLabelPropertyRange); +BOOST_CLASS_EXPORT(query::plan::ScanAllByLabelPropertyValue); +BOOST_CLASS_EXPORT(query::plan::Expand); +BOOST_CLASS_EXPORT(query::plan::ExpandVariable); +BOOST_CLASS_EXPORT(query::plan::Filter); +BOOST_CLASS_EXPORT(query::plan::Produce); +BOOST_CLASS_EXPORT(query::plan::ConstructNamedPath); +BOOST_CLASS_EXPORT(query::plan::Delete); +BOOST_CLASS_EXPORT(query::plan::SetProperty); +BOOST_CLASS_EXPORT(query::plan::SetProperties); +BOOST_CLASS_EXPORT(query::plan::SetLabels); +BOOST_CLASS_EXPORT(query::plan::RemoveProperty); +BOOST_CLASS_EXPORT(query::plan::RemoveLabels); +BOOST_CLASS_EXPORT(query::plan::ExpandUniquenessFilter<EdgeAccessor>); +BOOST_CLASS_EXPORT(query::plan::ExpandUniquenessFilter<VertexAccessor>); +BOOST_CLASS_EXPORT(query::plan::Accumulate); +BOOST_CLASS_EXPORT(query::plan::Aggregate); +BOOST_CLASS_EXPORT(query::plan::Skip); +BOOST_CLASS_EXPORT(query::plan::Limit); +BOOST_CLASS_EXPORT(query::plan::OrderBy); +BOOST_CLASS_EXPORT(query::plan::Merge); +BOOST_CLASS_EXPORT(query::plan::Optional); +BOOST_CLASS_EXPORT(query::plan::Unwind); +BOOST_CLASS_EXPORT(query::plan::Distinct); +BOOST_CLASS_EXPORT(query::plan::CreateIndex); +BOOST_CLASS_EXPORT(query::plan::Union); diff --git a/src/query/plan/operator.hpp b/src/query/plan/operator.hpp index 0161f87b1..000c4d41a 100644 --- a/src/query/plan/operator.hpp +++ b/src/query/plan/operator.hpp @@ -2,33 +2,38 @@ #pragma once -#include <glog/logging.h> -#include <algorithm> -#include <deque> #include <experimental/optional> #include <memory> -#include <tuple> -#include <type_traits> #include <unordered_map> #include <unordered_set> +#include <utility> #include <vector> -#include "database/graph_db_accessor.hpp" +#include <boost/serialization/shared_ptr_helper.hpp> +#include "boost/serialization/base_object.hpp" +#include "boost/serialization/serialization.hpp" +#include "boost/serialization/shared_ptr.hpp" +#include "boost/serialization/unique_ptr.hpp" + #include "query/common.hpp" -#include "query/exceptions.hpp" -#include "query/frontend/semantic/symbol_table.hpp" -#include "query/path.hpp" +#include "query/frontend/ast/ast.hpp" +#include "query/frontend/semantic/symbol.hpp" #include "query/typed_value.hpp" #include "storage/types.hpp" #include "utils/bound.hpp" #include "utils/hashing/fnv.hpp" #include "utils/visitor.hpp" +namespace database { +class GraphDbAccessor; +} + namespace query { class Context; class ExpressionEvaluator; class Frame; +class SymbolTable; namespace plan { @@ -150,8 +155,23 @@ class LogicalOperator } virtual ~LogicalOperator() {} + + private: + friend class boost::serialization::access; + + template <class TArchive> + void serialize(TArchive &ar, const unsigned int) {} }; +template <class TArchive> +std::pair<std::unique_ptr<LogicalOperator>, AstTreeStorage> LoadPlan( + TArchive &ar) { + std::unique_ptr<LogicalOperator> root; + ar >> root; + return {std::move(root), std::move(ar.template get_helper<AstTreeStorage>( + AstTreeStorage::kHelperId))}; +} + /** * A logical operator whose Cursor returns true on the first Pull * and false on every following Pull. @@ -172,6 +192,12 @@ class Once : public LogicalOperator { private: bool did_pull_{false}; }; + + friend class boost::serialization::access; + template <class TArchive> + void serialize(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object<LogicalOperator>(*this); + } }; /** @brief Operator for creating a node. @@ -192,15 +218,17 @@ class CreateNode : public LogicalOperator { * If a valid input, then a node will be created for each * successful pull from the given input. */ - CreateNode(const NodeAtom *node_atom, + CreateNode(NodeAtom *node_atom, const std::shared_ptr<LogicalOperator> &input); bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; std::unique_ptr<Cursor> MakeCursor( database::GraphDbAccessor &db) const override; private: - const NodeAtom *node_atom_ = nullptr; - const std::shared_ptr<LogicalOperator> input_; + CreateNode() {} + + NodeAtom *node_atom_ = nullptr; + std::shared_ptr<LogicalOperator> input_; class CreateNodeCursor : public Cursor { public: @@ -218,6 +246,24 @@ class CreateNode : public LogicalOperator { */ void Create(Frame &, Context &); }; + + friend class boost::serialization::access; + + BOOST_SERIALIZATION_SPLIT_MEMBER(); + + template <class TArchive> + void save(TArchive &ar, const unsigned int) const { + ar &boost::serialization::base_object<LogicalOperator>(*this); + ar &input_; + SavePointer(ar, node_atom_); + } + + template <class TArchive> + void load(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object<LogicalOperator>(*this); + ar &input_; + LoadPointer(ar, node_atom_); + } }; /** @brief Operator for creating edges and destination nodes. @@ -247,7 +293,7 @@ class CreateExpand : public LogicalOperator { * @param existing_node @c bool indicating whether the @c node_atom refers to * an existing node. If @c false, the operator will also create the node. */ - CreateExpand(const NodeAtom *node_atom, const EdgeAtom *edge_atom, + CreateExpand(NodeAtom *node_atom, EdgeAtom *edge_atom, const std::shared_ptr<LogicalOperator> &input, Symbol input_symbol, bool existing_node); bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; @@ -256,17 +302,19 @@ class CreateExpand : public LogicalOperator { private: // info on what's getting expanded - const NodeAtom *node_atom_; - const EdgeAtom *edge_atom_; + NodeAtom *node_atom_; + EdgeAtom *edge_atom_; // the input op and the symbol under which the op's result // can be found in the frame - const std::shared_ptr<LogicalOperator> input_; - const Symbol input_symbol_; + std::shared_ptr<LogicalOperator> input_; + Symbol input_symbol_; // if the given node atom refers to an existing node // (either matched or created) - const bool existing_node_; + bool existing_node_; + + CreateExpand() {} class CreateExpandCursor : public Cursor { public: @@ -298,6 +346,30 @@ class CreateExpand : public LogicalOperator { const SymbolTable &symbol_table, ExpressionEvaluator &evaluator); }; + + friend class boost::serialization::access; + + BOOST_SERIALIZATION_SPLIT_MEMBER(); + + template <class TArchive> + void save(TArchive &ar, const unsigned int) const { + ar &boost::serialization::base_object<LogicalOperator>(*this); + SavePointer(ar, node_atom_); + SavePointer(ar, edge_atom_); + ar &input_; + ar &input_symbol_; + ar &existing_node_; + } + + template <class TArchive> + void load(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object<LogicalOperator>(*this); + LoadPointer(ar, node_atom_); + LoadPointer(ar, edge_atom_); + ar &input_; + ar &input_symbol_; + ar &existing_node_; + } }; /** @@ -329,8 +401,8 @@ class ScanAll : public LogicalOperator { auto graph_view() const { return graph_view_; } protected: - const std::shared_ptr<LogicalOperator> input_; - const Symbol output_symbol_; + std::shared_ptr<LogicalOperator> input_; + Symbol output_symbol_; /** * @brief Controls which graph state is used to produce vertices. * @@ -339,7 +411,20 @@ class ScanAll : public LogicalOperator { * command. With @c GraphView::NEW, all vertices will be produced the current * transaction sees along with their modifications. */ - const GraphView graph_view_; + GraphView graph_view_; + + ScanAll() {} + + private: + friend class boost::serialization::access; + + template <class TArchive> + void serialize(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object<LogicalOperator>(*this); + ar &input_; + ar &output_symbol_; + ar &graph_view_; + } }; /** @@ -362,7 +447,17 @@ class ScanAllByLabel : public ScanAll { storage::Label label() const { return label_; } private: - const storage::Label label_; + storage::Label label_; + + ScanAllByLabel() {} + + friend class boost::serialization::access; + + template <class TArchive> + void serialize(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object<ScanAll>(*this); + ar &label_; + } }; /** @@ -408,10 +503,57 @@ class ScanAllByLabelPropertyRange : public ScanAll { auto upper_bound() const { return upper_bound_; } private: - const storage::Label label_; - const storage::Property property_; + storage::Label label_; + storage::Property property_; std::experimental::optional<Bound> lower_bound_; std::experimental::optional<Bound> upper_bound_; + + ScanAllByLabelPropertyRange() {} + + friend class boost::serialization::access; + + BOOST_SERIALIZATION_SPLIT_MEMBER(); + + template <class TArchive> + void save(TArchive &ar, const unsigned int) const { + ar &boost::serialization::base_object<ScanAll>(*this); + ar &label_; + ar &property_; + auto save_bound = [&ar](auto &maybe_bound) { + if (!maybe_bound) { + ar & false; + return; + } + ar & true; + auto &bound = *maybe_bound; + ar &bound.type(); + SavePointer(ar, bound.value()); + }; + save_bound(lower_bound_); + save_bound(upper_bound_); + } + + template <class TArchive> + void load(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object<ScanAll>(*this); + ar &label_; + ar &property_; + auto load_bound = [&ar](auto &maybe_bound) { + bool has_bound = false; + ar &has_bound; + if (!has_bound) { + maybe_bound = std::experimental::nullopt; + return; + } + utils::BoundType type; + ar &type; + Expression *value; + LoadPointer(ar, value); + maybe_bound = std::experimental::make_optional(Bound(value, type)); + }; + load_bound(lower_bound_); + load_bound(upper_bound_); + } }; /** @@ -449,9 +591,31 @@ class ScanAllByLabelPropertyValue : public ScanAll { auto expression() const { return expression_; } private: - const storage::Label label_; - const storage::Property property_; + storage::Label label_; + storage::Property property_; Expression *expression_; + + ScanAllByLabelPropertyValue() {} + + friend class boost::serialization::access; + + BOOST_SERIALIZATION_SPLIT_MEMBER(); + + template <class TArchive> + void save(TArchive &ar, const unsigned int) const { + ar &boost::serialization::base_object<ScanAll>(*this); + ar &label_; + ar &property_; + SavePointer(ar, expression_); + } + + template <class TArchive> + void load(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object<ScanAll>(*this); + ar &label_; + ar &property_; + LoadPointer(ar, expression_); + } }; /** @@ -511,22 +675,22 @@ class ExpandCommon { protected: // info on what's getting expanded - const Symbol node_symbol_; - const Symbol edge_symbol_; - const EdgeAtom::Direction direction_; - const std::vector<storage::EdgeType> edge_types_; + Symbol node_symbol_; + Symbol edge_symbol_; + EdgeAtom::Direction direction_; + std::vector<storage::EdgeType> edge_types_; // the input op and the symbol under which the op's result // can be found in the frame - const std::shared_ptr<LogicalOperator> input_; - const Symbol input_symbol_; + std::shared_ptr<LogicalOperator> input_; + Symbol input_symbol_; // If the given node atom refer to a symbol that has already been expanded and // should be just validated in the frame. - const bool existing_node_; + bool existing_node_; // from which state the input node should get expanded - const GraphView graph_view_; + GraphView graph_view_; /** * For a newly expanded node handles existence checking and @@ -537,6 +701,23 @@ class ExpandCommon { * the old. */ bool HandleExistingNode(const VertexAccessor &new_node, Frame &frame) const; + + ExpandCommon() {} + + private: + friend class boost::serialization::access; + + template <class TArchive> + void serialize(TArchive &ar, const unsigned int) { + ar &node_symbol_; + ar &edge_symbol_; + ar &direction_; + ar &edge_types_; + ar &input_; + ar &input_symbol_; + ar &existing_node_; + ar &graph_view_; + } }; /** @@ -586,6 +767,15 @@ class Expand : public LogicalOperator, public ExpandCommon { bool InitEdges(Frame &, Context &); }; + + private: + friend class boost::serialization::access; + + template <class TArchive> + void serialize(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object<LogicalOperator>(*this); + ar &boost::serialization::base_object<ExpandCommon>(*this); + } }; /** @@ -655,20 +845,52 @@ class ExpandVariable : public LogicalOperator, public ExpandCommon { auto type() const { return type_; } private: - const EdgeAtom::Type type_; + EdgeAtom::Type type_; // True if the path should be written as expanding from node_symbol to // input_symbol. - const bool is_reverse_; + bool is_reverse_; // lower and upper bounds of the variable length expansion // both are optional, defaults are (1, inf) Expression *lower_bound_; Expression *upper_bound_; // symbols for a single node and edge that are currently getting expanded - const Symbol inner_edge_symbol_; - const Symbol inner_node_symbol_; + Symbol inner_edge_symbol_; + Symbol inner_node_symbol_; // a filtering expression for skipping expansions during expansion // can refer to inner node and edges Expression *filter_; + + ExpandVariable() {} + + friend class boost::serialization::access; + + BOOST_SERIALIZATION_SPLIT_MEMBER(); + + template <class TArchive> + void save(TArchive &ar, const unsigned int) const { + ar &boost::serialization::base_object<LogicalOperator>(*this); + ar &boost::serialization::base_object<ExpandCommon>(*this); + ar &type_; + ar &is_reverse_; + SavePointer(ar, lower_bound_); + SavePointer(ar, upper_bound_); + ar &inner_edge_symbol_; + ar &inner_node_symbol_; + SavePointer(ar, filter_); + } + + template <class TArchive> + void load(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object<LogicalOperator>(*this); + ar &boost::serialization::base_object<ExpandCommon>(*this); + ar &type_; + ar &is_reverse_; + LoadPointer(ar, lower_bound_); + LoadPointer(ar, upper_bound_); + ar &inner_edge_symbol_; + ar &inner_node_symbol_; + LoadPointer(ar, filter_); + } }; /** @@ -691,9 +913,21 @@ class ConstructNamedPath : public LogicalOperator { const auto &path_elements() const { return path_elements_; } private: - const std::shared_ptr<LogicalOperator> input_; - const Symbol path_symbol_; - const std::vector<Symbol> path_elements_; + std::shared_ptr<LogicalOperator> input_; + Symbol path_symbol_; + std::vector<Symbol> path_elements_; + + ConstructNamedPath() {} + + friend class boost::serialization::access; + + template <class TArchive> + void serialize(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object<LogicalOperator>(*this); + ar &input_; + ar &path_symbol_; + ar &path_elements_; + } }; /** @@ -713,9 +947,11 @@ class Filter : public LogicalOperator { database::GraphDbAccessor &db) const override; private: - const std::shared_ptr<LogicalOperator> input_; + std::shared_ptr<LogicalOperator> input_; Expression *expression_; + Filter() {} + class FilterCursor : public Cursor { public: FilterCursor(const Filter &self, database::GraphDbAccessor &db); @@ -727,6 +963,24 @@ class Filter : public LogicalOperator { database::GraphDbAccessor &db_; const std::unique_ptr<Cursor> input_cursor_; }; + + friend class boost::serialization::access; + + BOOST_SERIALIZATION_SPLIT_MEMBER(); + + template <class TArchive> + void save(TArchive &ar, const unsigned int) const { + ar &boost::serialization::base_object<LogicalOperator>(*this); + ar &input_; + SavePointer(ar, expression_); + } + + template <class TArchive> + void load(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object<LogicalOperator>(*this); + ar &input_; + LoadPointer(ar, expression_); + } }; /** @@ -751,8 +1005,8 @@ class Produce : public LogicalOperator { const std::vector<NamedExpression *> &named_expressions(); private: - const std::shared_ptr<LogicalOperator> input_; - const std::vector<NamedExpression *> named_expressions_; + std::shared_ptr<LogicalOperator> input_; + std::vector<NamedExpression *> named_expressions_; class ProduceCursor : public Cursor { public: @@ -765,6 +1019,16 @@ class Produce : public LogicalOperator { database::GraphDbAccessor &db_; const std::unique_ptr<Cursor> input_cursor_; }; + + Produce() {} + friend class boost::serialization::access; + + template <class TArchive> + void serialize(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object<LogicalOperator>(*this); + ar &input_; + ar &named_expressions_; + } }; /** @@ -782,12 +1046,14 @@ class Delete : public LogicalOperator { database::GraphDbAccessor &db) const override; private: - const std::shared_ptr<LogicalOperator> input_; - const std::vector<Expression *> expressions_; + std::shared_ptr<LogicalOperator> input_; + std::vector<Expression *> expressions_; // if the vertex should be detached before deletion // if not detached, and has connections, an error is raised // ignored when deleting edges - const bool detach_; + bool detach_; + + Delete() {} class DeleteCursor : public Cursor { public: @@ -800,6 +1066,26 @@ class Delete : public LogicalOperator { database::GraphDbAccessor &db_; const std::unique_ptr<Cursor> input_cursor_; }; + + friend class boost::serialization::access; + + BOOST_SERIALIZATION_SPLIT_MEMBER(); + + template <class TArchive> + void save(TArchive &ar, const unsigned int) const { + ar &boost::serialization::base_object<LogicalOperator>(*this); + ar &input_; + SavePointers(ar, expressions_); + ar &detach_; + } + + template <class TArchive> + void load(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object<LogicalOperator>(*this); + ar &input_; + LoadPointers(ar, expressions_); + ar &detach_; + } }; /** @@ -817,10 +1103,12 @@ class SetProperty : public LogicalOperator { database::GraphDbAccessor &db) const override; private: - const std::shared_ptr<LogicalOperator> input_; + std::shared_ptr<LogicalOperator> input_; PropertyLookup *lhs_; Expression *rhs_; + SetProperty() {} + class SetPropertyCursor : public Cursor { public: SetPropertyCursor(const SetProperty &self, database::GraphDbAccessor &db); @@ -832,6 +1120,26 @@ class SetProperty : public LogicalOperator { database::GraphDbAccessor &db_; const std::unique_ptr<Cursor> input_cursor_; }; + + friend class boost::serialization::access; + + BOOST_SERIALIZATION_SPLIT_MEMBER(); + + template <class TArchive> + void save(TArchive &ar, const unsigned int) const { + ar &boost::serialization::base_object<LogicalOperator>(*this); + ar &input_; + SavePointer(ar, lhs_); + SavePointer(ar, rhs_); + } + + template <class TArchive> + void load(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object<LogicalOperator>(*this); + ar &input_; + LoadPointer(ar, lhs_); + LoadPointer(ar, rhs_); + } }; /** @@ -865,11 +1173,13 @@ class SetProperties : public LogicalOperator { database::GraphDbAccessor &db) const override; private: - const std::shared_ptr<LogicalOperator> input_; - const Symbol input_symbol_; + std::shared_ptr<LogicalOperator> input_; + Symbol input_symbol_; Expression *rhs_; Op op_; + SetProperties() {} + class SetPropertiesCursor : public Cursor { public: SetPropertiesCursor(const SetProperties &self, @@ -890,6 +1200,28 @@ class SetProperties : public LogicalOperator { template <typename TRecordAccessor> void Set(TRecordAccessor &record, const TypedValue &rhs) const; }; + + friend class boost::serialization::access; + + BOOST_SERIALIZATION_SPLIT_MEMBER(); + + template <class TArchive> + void save(TArchive &ar, const unsigned int) const { + ar &boost::serialization::base_object<LogicalOperator>(*this); + ar &input_; + ar &input_symbol_; + SavePointer(ar, rhs_); + ar &op_; + } + + template <class TArchive> + void load(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object<LogicalOperator>(*this); + ar &input_; + ar &input_symbol_; + LoadPointer(ar, rhs_); + ar &op_; + } }; /** @@ -907,9 +1239,11 @@ class SetLabels : public LogicalOperator { database::GraphDbAccessor &db) const override; private: - const std::shared_ptr<LogicalOperator> input_; - const Symbol input_symbol_; - const std::vector<storage::Label> labels_; + std::shared_ptr<LogicalOperator> input_; + Symbol input_symbol_; + std::vector<storage::Label> labels_; + + SetLabels() {} class SetLabelsCursor : public Cursor { public: @@ -921,6 +1255,16 @@ class SetLabels : public LogicalOperator { const SetLabels &self_; const std::unique_ptr<Cursor> input_cursor_; }; + + friend class boost::serialization::access; + + template <class TArchive> + void serialize(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object<LogicalOperator>(*this); + ar &input_; + ar &input_symbol_; + ar &labels_; + } }; /** @@ -936,9 +1280,11 @@ class RemoveProperty : public LogicalOperator { database::GraphDbAccessor &db) const override; private: - const std::shared_ptr<LogicalOperator> input_; + std::shared_ptr<LogicalOperator> input_; PropertyLookup *lhs_; + RemoveProperty() {} + class RemovePropertyCursor : public Cursor { public: RemovePropertyCursor(const RemoveProperty &self, @@ -951,6 +1297,24 @@ class RemoveProperty : public LogicalOperator { database::GraphDbAccessor &db_; const std::unique_ptr<Cursor> input_cursor_; }; + + friend class boost::serialization::access; + + BOOST_SERIALIZATION_SPLIT_MEMBER(); + + template <class TArchive> + void save(TArchive &ar, const unsigned int) const { + ar &boost::serialization::base_object<LogicalOperator>(*this); + ar &input_; + SavePointer(ar, lhs_); + } + + template <class TArchive> + void load(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object<LogicalOperator>(*this); + ar &input_; + LoadPointer(ar, lhs_); + } }; /** @@ -968,9 +1332,11 @@ class RemoveLabels : public LogicalOperator { database::GraphDbAccessor &db) const override; private: - const std::shared_ptr<LogicalOperator> input_; - const Symbol input_symbol_; - const std::vector<storage::Label> labels_; + std::shared_ptr<LogicalOperator> input_; + Symbol input_symbol_; + std::vector<storage::Label> labels_; + + RemoveLabels() {} class RemoveLabelsCursor : public Cursor { public: @@ -982,6 +1348,16 @@ class RemoveLabels : public LogicalOperator { const RemoveLabels &self_; const std::unique_ptr<Cursor> input_cursor_; }; + + friend class boost::serialization::access; + + template <class TArchive> + void serialize(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object<LogicalOperator>(*this); + ar &input_; + ar &input_symbol_; + ar &labels_; + } }; /** @@ -1018,9 +1394,11 @@ class ExpandUniquenessFilter : public LogicalOperator { database::GraphDbAccessor &db) const override; private: - const std::shared_ptr<LogicalOperator> input_; + std::shared_ptr<LogicalOperator> input_; Symbol expand_symbol_; - const std::vector<Symbol> previous_symbols_; + std::vector<Symbol> previous_symbols_; + + ExpandUniquenessFilter() {} class ExpandUniquenessFilterCursor : public Cursor { public: @@ -1033,6 +1411,16 @@ class ExpandUniquenessFilter : public LogicalOperator { const ExpandUniquenessFilter &self_; const std::unique_ptr<Cursor> input_cursor_; }; + + friend class boost::serialization::access; + + template <class TArchive> + void serialize(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object<LogicalOperator>(*this); + ar &input_; + ar &expand_symbol_; + ar &previous_symbols_; + } }; /** @brief Pulls everything from the input before passing it through. @@ -1073,9 +1461,11 @@ class Accumulate : public LogicalOperator { const auto &symbols() const { return symbols_; }; private: - const std::shared_ptr<LogicalOperator> input_; - const std::vector<Symbol> symbols_; - const bool advance_command_; + std::shared_ptr<LogicalOperator> input_; + std::vector<Symbol> symbols_; + bool advance_command_; + + Accumulate() {} class AccumulateCursor : public Cursor { public: @@ -1091,6 +1481,16 @@ class Accumulate : public LogicalOperator { decltype(cache_.begin()) cache_it_ = cache_.begin(); bool pulled_all_input_{false}; }; + + friend class boost::serialization::access; + + template <class TArchive> + void serialize(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object<LogicalOperator>(*this); + ar &input_; + ar &symbols_; + ar &advance_command_; + } }; /** @@ -1127,6 +1527,27 @@ class Aggregate : public LogicalOperator { Expression *key; Aggregation::Op op; Symbol output_sym; + + private: + friend class boost::serialization::access; + + BOOST_SERIALIZATION_SPLIT_MEMBER(); + + template <class TArchive> + void save(TArchive &ar, const unsigned int) const { + SavePointer(ar, value); + SavePointer(ar, key); + ar &op; + ar &output_sym; + } + + template <class TArchive> + void load(TArchive &ar, const unsigned int) { + LoadPointer(ar, value); + LoadPointer(ar, key); + ar &op; + ar &output_sym; + } }; Aggregate(const std::shared_ptr<LogicalOperator> &input, @@ -1141,10 +1562,12 @@ class Aggregate : public LogicalOperator { const auto &group_by() const { return group_by_; } private: - const std::shared_ptr<LogicalOperator> input_; - const std::vector<Element> aggregations_; - const std::vector<Expression *> group_by_; - const std::vector<Symbol> remember_; + std::shared_ptr<LogicalOperator> input_; + std::vector<Element> aggregations_; + std::vector<Expression *> group_by_; + std::vector<Symbol> remember_; + + Aggregate() {} class AggregateCursor : public Cursor { public: @@ -1224,6 +1647,28 @@ class Aggregate : public LogicalOperator { * an appropriate exception is thrown. */ void EnsureOkForAvgSum(const TypedValue &value) const; }; + + friend class boost::serialization::access; + + BOOST_SERIALIZATION_SPLIT_MEMBER(); + + template <class TArchive> + void save(TArchive &ar, const unsigned int) const { + ar &boost::serialization::base_object<LogicalOperator>(*this); + ar &input_; + ar &aggregations_; + SavePointers(ar, group_by_); + ar &remember_; + } + + template <class TArchive> + void load(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object<LogicalOperator>(*this); + ar &input_; + ar &aggregations_; + LoadPointers(ar, group_by_); + ar &remember_; + } }; /** @brief Skips a number of Pulls from the input op. @@ -1247,9 +1692,11 @@ class Skip : public LogicalOperator { std::vector<Symbol> OutputSymbols(const SymbolTable &) const override; private: - const std::shared_ptr<LogicalOperator> input_; + std::shared_ptr<LogicalOperator> input_; Expression *expression_; + Skip() {} + class SkipCursor : public Cursor { public: SkipCursor(const Skip &self, database::GraphDbAccessor &db); @@ -1265,6 +1712,24 @@ class Skip : public LogicalOperator { int to_skip_{-1}; int skipped_{0}; }; + + friend class boost::serialization::access; + + BOOST_SERIALIZATION_SPLIT_MEMBER(); + + template <class TArchive> + void save(TArchive &ar, const unsigned int) const { + ar &boost::serialization::base_object<LogicalOperator>(*this); + ar &input_; + SavePointer(ar, expression_); + } + + template <class TArchive> + void load(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object<LogicalOperator>(*this); + ar &input_; + LoadPointer(ar, expression_); + } }; /** @brief Limits the number of Pulls from the input op. @@ -1291,9 +1756,11 @@ class Limit : public LogicalOperator { std::vector<Symbol> OutputSymbols(const SymbolTable &) const override; private: - const std::shared_ptr<LogicalOperator> input_; + std::shared_ptr<LogicalOperator> input_; Expression *expression_; + Limit() {} + class LimitCursor : public Cursor { public: LimitCursor(const Limit &self, database::GraphDbAccessor &db); @@ -1309,6 +1776,24 @@ class Limit : public LogicalOperator { int limit_{-1}; int pulled_{0}; }; + + friend class boost::serialization::access; + + BOOST_SERIALIZATION_SPLIT_MEMBER(); + + template <class TArchive> + void save(TArchive &ar, const unsigned int) const { + ar &boost::serialization::base_object<LogicalOperator>(*this); + ar &input_; + SavePointer(ar, expression_); + } + + template <class TArchive> + void load(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object<LogicalOperator>(*this); + ar &input_; + LoadPointer(ar, expression_); + } }; /** @brief Logical operator for ordering (sorting) results. @@ -1349,12 +1834,21 @@ class OrderBy : public LogicalOperator { private: std::vector<Ordering> ordering_; + + friend class boost::serialization::access; + + template <class TArchive> + void serialize(TArchive &ar, const unsigned int) { + ar &ordering_; + } }; - const std::shared_ptr<LogicalOperator> input_; + std::shared_ptr<LogicalOperator> input_; TypedValueVectorCompare compare_; std::vector<Expression *> order_by_; - const std::vector<Symbol> output_symbols_; + std::vector<Symbol> output_symbols_; + + OrderBy() {} // custom comparison for TypedValue objects // behaves generally like Neo's ORDER BY comparison operator: @@ -1383,6 +1877,28 @@ class OrderBy : public LogicalOperator { // iterator over the cache_, maintains state between Pulls decltype(cache_.begin()) cache_it_ = cache_.begin(); }; + + friend class boost::serialization::access; + + BOOST_SERIALIZATION_SPLIT_MEMBER(); + + template <class TArchive> + void save(TArchive &ar, const unsigned int) const { + ar &boost::serialization::base_object<LogicalOperator>(*this); + ar &input_; + ar &compare_; + SavePointers(ar, order_by_); + ar &output_symbols_; + } + + template <class TArchive> + void load(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object<LogicalOperator>(*this); + ar &input_; + ar &compare_; + LoadPointers(ar, order_by_); + ar &output_symbols_; + } }; /** @@ -1412,9 +1928,11 @@ class Merge : public LogicalOperator { auto merge_create() const { return merge_create_; } private: - const std::shared_ptr<LogicalOperator> input_; - const std::shared_ptr<LogicalOperator> merge_match_; - const std::shared_ptr<LogicalOperator> merge_create_; + std::shared_ptr<LogicalOperator> input_; + std::shared_ptr<LogicalOperator> merge_match_; + std::shared_ptr<LogicalOperator> merge_create_; + + Merge() {} class MergeCursor : public Cursor { public: @@ -1434,6 +1952,16 @@ class Merge : public LogicalOperator { // - previous Pull from this cursor exhausted the merge_match_cursor bool pull_input_{true}; }; + + friend class boost::serialization::access; + + template <class TArchive> + void serialize(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object<LogicalOperator>(*this); + ar &input_; + ar &merge_match_; + ar &merge_create_; + } }; /** @@ -1458,9 +1986,11 @@ class Optional : public LogicalOperator { const auto &optional_symbols() const { return optional_symbols_; } private: - const std::shared_ptr<LogicalOperator> input_; - const std::shared_ptr<LogicalOperator> optional_; - const std::vector<Symbol> optional_symbols_; + std::shared_ptr<LogicalOperator> input_; + std::shared_ptr<LogicalOperator> optional_; + std::vector<Symbol> optional_symbols_; + + Optional() {} class OptionalCursor : public Cursor { public: @@ -1479,6 +2009,16 @@ class Optional : public LogicalOperator { // - previous Pull from this cursor exhausted the optional_cursor_ bool pull_input_{true}; }; + + friend class boost::serialization::access; + + template <class TArchive> + void serialize(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object<LogicalOperator>(*this); + ar &input_; + ar &optional_; + ar &optional_symbols_; + } }; /** @@ -1498,9 +2038,11 @@ class Unwind : public LogicalOperator { Expression *input_expression() const { return input_expression_; } private: - const std::shared_ptr<LogicalOperator> input_; + std::shared_ptr<LogicalOperator> input_; Expression *input_expression_; - const Symbol output_symbol_; + Symbol output_symbol_; + + Unwind() {} class UnwindCursor : public Cursor { public: @@ -1517,6 +2059,26 @@ class Unwind : public LogicalOperator { // current position in input_value_ std::vector<TypedValue>::iterator input_value_it_ = input_value_.end(); }; + + friend class boost::serialization::access; + + BOOST_SERIALIZATION_SPLIT_MEMBER(); + + template <class TArchive> + void save(TArchive &ar, const unsigned int) const { + ar &boost::serialization::base_object<LogicalOperator>(*this); + ar &input_; + SavePointer(ar, input_expression_); + ar &output_symbol_; + } + + template <class TArchive> + void load(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object<LogicalOperator>(*this); + ar &input_; + LoadPointer(ar, input_expression_); + ar &output_symbol_; + } }; /** @@ -1537,8 +2099,10 @@ class Distinct : public LogicalOperator { std::vector<Symbol> OutputSymbols(const SymbolTable &) const override; private: - const std::shared_ptr<LogicalOperator> input_; - const std::vector<Symbol> value_symbols_; + std::shared_ptr<LogicalOperator> input_; + std::vector<Symbol> value_symbols_; + + Distinct() {} class DistinctCursor : public Cursor { public: @@ -1558,6 +2122,15 @@ class Distinct : public LogicalOperator { TypedValueVectorEqual> seen_rows_; }; + + friend class boost::serialization::access; + + template <class TArchive> + void serialize(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object<LogicalOperator>(*this); + ar &input_; + ar &value_symbols_; + } }; /** @@ -1581,6 +2154,17 @@ class CreateIndex : public LogicalOperator { private: storage::Label label_; storage::Property property_; + + CreateIndex() {} + + friend class boost::serialization::access; + + template <class TArchive> + void serialize(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object<LogicalOperator>(*this); + ar &label_; + ar &property_; + } }; /** @@ -1603,8 +2187,10 @@ class Union : public LogicalOperator { std::vector<Symbol> OutputSymbols(const SymbolTable &) const override; private: - const std::shared_ptr<LogicalOperator> left_op_, right_op_; - const std::vector<Symbol> union_symbols_, left_symbols_, right_symbols_; + std::shared_ptr<LogicalOperator> left_op_, right_op_; + std::vector<Symbol> union_symbols_, left_symbols_, right_symbols_; + + Union() {} class UnionCursor : public Cursor { public: @@ -1616,6 +2202,18 @@ class Union : public LogicalOperator { const Union &self_; const std::unique_ptr<Cursor> left_cursor_, right_cursor_; }; + + friend class boost::serialization::access; + + template <class TArchive> + void serialize(TArchive &ar, const unsigned int) { + ar &boost::serialization::base_object<LogicalOperator>(*this); + ar &left_op_; + ar &right_op_; + ar &union_symbols_; + ar &left_symbols_; + ar &right_symbols_; + } }; } // namespace plan diff --git a/tests/unit/query_planner.cpp b/tests/unit/query_planner.cpp index e78299861..4f1ba1527 100644 --- a/tests/unit/query_planner.cpp +++ b/tests/unit/query_planner.cpp @@ -1,7 +1,11 @@ #include <list> +#include <sstream> #include <tuple> +#include <typeinfo> #include <unordered_set> +#include "boost/archive/binary_iarchive.hpp" +#include "boost/archive/binary_oarchive.hpp" #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -182,15 +186,23 @@ class ExpectAggregate : public OpChecker<Aggregate> { for (const auto &aggr_elem : op.aggregations()) { ASSERT_NE(aggr_it, aggregations_.end()); auto aggr = *aggr_it++; - EXPECT_EQ(aggr_elem.value, aggr->expression1_); - EXPECT_EQ(aggr_elem.key, aggr->expression2_); + // TODO: Proper expression equality + EXPECT_EQ(typeid(aggr_elem.value).hash_code(), + typeid(aggr->expression1_).hash_code()); + EXPECT_EQ(typeid(aggr_elem.key).hash_code(), + typeid(aggr->expression2_).hash_code()); EXPECT_EQ(aggr_elem.op, aggr->op_); EXPECT_EQ(aggr_elem.output_sym, symbol_table.at(*aggr)); } EXPECT_EQ(aggr_it, aggregations_.end()); - auto got_group_by = std::unordered_set<query::Expression *>( - op.group_by().begin(), op.group_by().end()); - EXPECT_EQ(group_by_, got_group_by); + // TODO: Proper group by expression equality + std::unordered_set<size_t> got_group_by; + for (auto *expr : op.group_by()) + got_group_by.insert(typeid(*expr).hash_code()); + std::unordered_set<size_t> expected_group_by; + for (auto *expr : group_by_) + expected_group_by.insert(typeid(*expr).hash_code()); + EXPECT_EQ(got_group_by, expected_group_by); } private: @@ -252,7 +264,9 @@ class ExpectScanAllByLabelPropertyValue const SymbolTable &) override { EXPECT_EQ(scan_all.label(), label_); EXPECT_EQ(scan_all.property(), property_); - EXPECT_EQ(scan_all.expression(), expression_); + // TODO: Proper expression equality + EXPECT_EQ(typeid(scan_all.expression()).hash_code(), + typeid(expression_).hash_code()); } private: @@ -279,12 +293,16 @@ class ExpectScanAllByLabelPropertyRange EXPECT_EQ(scan_all.property(), property_); if (lower_bound_) { ASSERT_TRUE(scan_all.lower_bound()); - EXPECT_EQ(scan_all.lower_bound()->value(), lower_bound_->value()); + // TODO: Proper expression equality + EXPECT_EQ(typeid(scan_all.lower_bound()->value()).hash_code(), + typeid(lower_bound_->value()).hash_code()); EXPECT_EQ(scan_all.lower_bound()->type(), lower_bound_->type()); } if (upper_bound_) { ASSERT_TRUE(scan_all.upper_bound()); - EXPECT_EQ(scan_all.upper_bound()->value(), upper_bound_->value()); + // TODO: Proper expression equality + EXPECT_EQ(typeid(scan_all.upper_bound()->value()).hash_code(), + typeid(upper_bound_->value()).hash_code()); EXPECT_EQ(scan_all.upper_bound()->type(), upper_bound_->type()); } } @@ -318,6 +336,44 @@ auto MakeSymbolTable(query::Query &query) { return symbol_table; } +class Planner { + public: + Planner(std::vector<SingleQueryPart> single_query_parts, + PlanningContext<database::GraphDbAccessor> &context) { + plan_ = MakeLogicalPlanForSingleQuery<RuleBasedPlanner>(single_query_parts, + context); + } + + auto &plan() { return *plan_; } + + private: + std::unique_ptr<LogicalOperator> plan_; +}; + +class SerializedPlanner { + public: + SerializedPlanner(std::vector<SingleQueryPart> single_query_parts, + PlanningContext<database::GraphDbAccessor> &context) { + std::stringstream stream; + { + auto original_plan = MakeLogicalPlanForSingleQuery<RuleBasedPlanner>( + single_query_parts, context); + boost::archive::binary_oarchive out_archive(stream); + out_archive << original_plan; + } + { + boost::archive::binary_iarchive in_archive(stream); + std::tie(plan_, ast_storage_) = LoadPlan(in_archive); + } + } + + auto &plan() { return *plan_; } + + private: + AstTreeStorage ast_storage_; + std::unique_ptr<LogicalOperator> plan_; +}; + template <class... TChecker> auto CheckPlan(LogicalOperator &plan, const SymbolTable &symbol_table, TChecker... checker) { @@ -327,7 +383,7 @@ auto CheckPlan(LogicalOperator &plan, const SymbolTable &symbol_table, EXPECT_TRUE(plan_checker.checkers_.empty()); } -template <class... TChecker> +template <class TPlanner, class... TChecker> auto CheckPlan(AstTreeStorage &storage, TChecker... checker) { auto symbol_table = MakeSymbolTable(*storage.query()); database::SingleNode db; @@ -336,19 +392,25 @@ auto CheckPlan(AstTreeStorage &storage, TChecker... checker) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery<RuleBasedPlanner>( - single_query_parts, planning_context); - CheckPlan(*plan, symbol_table, checker...); + TPlanner planner(single_query_parts, planning_context); + CheckPlan(planner.plan(), symbol_table, checker...); } -TEST(TestLogicalPlanner, MatchNodeReturn) { +template <class T> +class TestPlanner : public ::testing::Test {}; + +using PlannerTypes = ::testing::Types<Planner, SerializedPlanner>; + +TYPED_TEST_CASE(TestPlanner, PlannerTypes); + +TYPED_TEST(TestPlanner, MatchNodeReturn) { // Test MATCH (n) RETURN n AstTreeStorage storage; QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), RETURN("n"))); - CheckPlan(storage, ExpectScanAll(), ExpectProduce()); + CheckPlan<TypeParam>(storage, ExpectScanAll(), ExpectProduce()); } -TEST(TestLogicalPlanner, CreateNodeReturn) { +TYPED_TEST(TestPlanner, CreateNodeReturn) { // Test CREATE (n) RETURN n AS n AstTreeStorage storage; auto ident_n = IDENT("n"); @@ -362,12 +424,12 @@ TEST(TestLogicalPlanner, CreateNodeReturn) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery<RuleBasedPlanner>( - single_query_parts, planning_context); - CheckPlan(*plan, symbol_table, ExpectCreateNode(), acc, ExpectProduce()); + TypeParam planner(single_query_parts, planning_context); + CheckPlan(planner.plan(), symbol_table, ExpectCreateNode(), acc, + ExpectProduce()); } -TEST(TestLogicalPlanner, CreateExpand) { +TYPED_TEST(TestPlanner, CreateExpand) { // Test CREATE (n) -[r :rel1]-> (m) AstTreeStorage storage; database::SingleNode db; @@ -375,17 +437,17 @@ TEST(TestLogicalPlanner, CreateExpand) { auto relationship = dba.EdgeType("relationship"); QUERY(SINGLE_QUERY(CREATE(PATTERN( NODE("n"), EDGE("r", Direction::OUT, {relationship}), NODE("m"))))); - CheckPlan(storage, ExpectCreateNode(), ExpectCreateExpand()); + CheckPlan<TypeParam>(storage, ExpectCreateNode(), ExpectCreateExpand()); } -TEST(TestLogicalPlanner, CreateMultipleNode) { +TYPED_TEST(TestPlanner, CreateMultipleNode) { // Test CREATE (n), (m) AstTreeStorage storage; QUERY(SINGLE_QUERY(CREATE(PATTERN(NODE("n")), PATTERN(NODE("m"))))); - CheckPlan(storage, ExpectCreateNode(), ExpectCreateNode()); + CheckPlan<TypeParam>(storage, ExpectCreateNode(), ExpectCreateNode()); } -TEST(TestLogicalPlanner, CreateNodeExpandNode) { +TYPED_TEST(TestPlanner, CreateNodeExpandNode) { // Test CREATE (n) -[r :rel]-> (m), (l) AstTreeStorage storage; database::SingleNode db; @@ -394,11 +456,11 @@ TEST(TestLogicalPlanner, CreateNodeExpandNode) { QUERY(SINGLE_QUERY(CREATE( PATTERN(NODE("n"), EDGE("r", Direction::OUT, {relationship}), NODE("m")), PATTERN(NODE("l"))))); - CheckPlan(storage, ExpectCreateNode(), ExpectCreateExpand(), - ExpectCreateNode()); + CheckPlan<TypeParam>(storage, ExpectCreateNode(), ExpectCreateExpand(), + ExpectCreateNode()); } -TEST(TestLogicalPlanner, CreateNamedPattern) { +TYPED_TEST(TestPlanner, CreateNamedPattern) { // Test CREATE p = (n) -[r :rel]-> (m) AstTreeStorage storage; database::SingleNode db; @@ -406,11 +468,11 @@ TEST(TestLogicalPlanner, CreateNamedPattern) { auto relationship = dba.EdgeType("rel"); QUERY(SINGLE_QUERY(CREATE(NAMED_PATTERN( "p", NODE("n"), EDGE("r", Direction::OUT, {relationship}), NODE("m"))))); - CheckPlan(storage, ExpectCreateNode(), ExpectCreateExpand(), - ExpectConstructNamedPath()); + CheckPlan<TypeParam>(storage, ExpectCreateNode(), ExpectCreateExpand(), + ExpectConstructNamedPath()); } -TEST(TestLogicalPlanner, MatchCreateExpand) { +TYPED_TEST(TestPlanner, MatchCreateExpand) { // Test MATCH (n) CREATE (n) -[r :rel1]-> (m) AstTreeStorage storage; database::SingleNode db; @@ -420,20 +482,20 @@ TEST(TestLogicalPlanner, MatchCreateExpand) { MATCH(PATTERN(NODE("n"))), CREATE(PATTERN(NODE("n"), EDGE("r", Direction::OUT, {relationship}), NODE("m"))))); - CheckPlan(storage, ExpectScanAll(), ExpectCreateExpand()); + CheckPlan<TypeParam>(storage, ExpectScanAll(), ExpectCreateExpand()); } -TEST(TestLogicalPlanner, MatchLabeledNodes) { +TYPED_TEST(TestPlanner, MatchLabeledNodes) { // Test MATCH (n :label) RETURN n AstTreeStorage storage; database::SingleNode db; database::GraphDbAccessor dba(db); auto label = dba.Label("label"); QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n", label))), RETURN("n"))); - CheckPlan(storage, ExpectScanAllByLabel(), ExpectProduce()); + CheckPlan<TypeParam>(storage, ExpectScanAllByLabel(), ExpectProduce()); } -TEST(TestLogicalPlanner, MatchPathReturn) { +TYPED_TEST(TestPlanner, MatchPathReturn) { // Test MATCH (n) -[r :relationship]- (m) RETURN n AstTreeStorage storage; database::SingleNode db; @@ -443,10 +505,11 @@ TEST(TestLogicalPlanner, MatchPathReturn) { MATCH(PATTERN(NODE("n"), EDGE("r", Direction::BOTH, {relationship}), NODE("m"))), RETURN("n"))); - CheckPlan(storage, ExpectScanAll(), ExpectExpand(), ExpectProduce()); + CheckPlan<TypeParam>(storage, ExpectScanAll(), ExpectExpand(), + ExpectProduce()); } -TEST(TestLogicalPlanner, MatchNamedPatternReturn) { +TYPED_TEST(TestPlanner, MatchNamedPatternReturn) { // Test MATCH p = (n) -[r :relationship]- (m) RETURN p AstTreeStorage storage; database::SingleNode db; @@ -457,11 +520,11 @@ TEST(TestLogicalPlanner, MatchNamedPatternReturn) { EDGE("r", Direction::BOTH, {relationship}), NODE("m"))), RETURN("n"))); - CheckPlan(storage, ExpectScanAll(), ExpectExpand(), - ExpectConstructNamedPath(), ExpectProduce()); + CheckPlan<TypeParam>(storage, ExpectScanAll(), ExpectExpand(), + ExpectConstructNamedPath(), ExpectProduce()); } -TEST(TestLogicalPlanner, MatchNamedPatternWithPredicateReturn) { +TYPED_TEST(TestPlanner, MatchNamedPatternWithPredicateReturn) { // Test MATCH p = (n) -[r :relationship]- (m) RETURN p AstTreeStorage storage; database::SingleNode db; @@ -472,11 +535,12 @@ TEST(TestLogicalPlanner, MatchNamedPatternWithPredicateReturn) { EDGE("r", Direction::BOTH, {relationship}), NODE("m"))), WHERE(EQ(LITERAL(2), IDENT("p"))), RETURN("n"))); - CheckPlan(storage, ExpectScanAll(), ExpectExpand(), - ExpectConstructNamedPath(), ExpectFilter(), ExpectProduce()); + CheckPlan<TypeParam>(storage, ExpectScanAll(), ExpectExpand(), + ExpectConstructNamedPath(), ExpectFilter(), + ExpectProduce()); } -TEST(TestLogicalPlanner, OptionalMatchNamedPatternReturn) { +TYPED_TEST(TestPlanner, OptionalMatchNamedPatternReturn) { // Test OPTIONAL MATCH p = (n) -[r]- (m) RETURN p database::SingleNode db; database::GraphDbAccessor dba(db); @@ -491,9 +555,6 @@ TEST(TestLogicalPlanner, OptionalMatchNamedPatternReturn) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery<RuleBasedPlanner>( - single_query_parts, planning_context); - std::list<BaseOpChecker *> optional{new ExpectScanAll(), new ExpectExpand(), new ExpectConstructNamedPath()}; auto get_symbol = [&symbol_table](const auto *ast_node) { @@ -501,11 +562,12 @@ TEST(TestLogicalPlanner, OptionalMatchNamedPatternReturn) { }; std::vector<Symbol> optional_symbols{get_symbol(pattern), get_symbol(node_n), get_symbol(edge), get_symbol(node_m)}; - CheckPlan(*plan, symbol_table, ExpectOptional(optional_symbols, optional), - ExpectProduce()); + TypeParam planner(single_query_parts, planning_context); + CheckPlan(planner.plan(), symbol_table, + ExpectOptional(optional_symbols, optional), ExpectProduce()); } -TEST(TestLogicalPlanner, MatchWhereReturn) { +TYPED_TEST(TestPlanner, MatchWhereReturn) { // Test MATCH (n) WHERE n.property < 42 RETURN n AstTreeStorage storage; database::SingleNode db; @@ -514,17 +576,18 @@ TEST(TestLogicalPlanner, MatchWhereReturn) { QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), WHERE(LESS(PROPERTY_LOOKUP("n", property), LITERAL(42))), RETURN("n"))); - CheckPlan(storage, ExpectScanAll(), ExpectFilter(), ExpectProduce()); + CheckPlan<TypeParam>(storage, ExpectScanAll(), ExpectFilter(), + ExpectProduce()); } -TEST(TestLogicalPlanner, MatchDelete) { +TYPED_TEST(TestPlanner, MatchDelete) { // Test MATCH (n) DELETE n AstTreeStorage storage; QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), DELETE(IDENT("n")))); - CheckPlan(storage, ExpectScanAll(), ExpectDelete()); + CheckPlan<TypeParam>(storage, ExpectScanAll(), ExpectDelete()); } -TEST(TestLogicalPlanner, MatchNodeSet) { +TYPED_TEST(TestPlanner, MatchNodeSet) { // Test MATCH (n) SET n.prop = 42, n = n, n :label AstTreeStorage storage; database::SingleNode db; @@ -534,11 +597,11 @@ TEST(TestLogicalPlanner, MatchNodeSet) { QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), SET(PROPERTY_LOOKUP("n", prop), LITERAL(42)), SET("n", IDENT("n")), SET("n", {label}))); - CheckPlan(storage, ExpectScanAll(), ExpectSetProperty(), - ExpectSetProperties(), ExpectSetLabels()); + CheckPlan<TypeParam>(storage, ExpectScanAll(), ExpectSetProperty(), + ExpectSetProperties(), ExpectSetLabels()); } -TEST(TestLogicalPlanner, MatchRemove) { +TYPED_TEST(TestPlanner, MatchRemove) { // Test MATCH (n) REMOVE n.prop REMOVE n :label AstTreeStorage storage; database::SingleNode db; @@ -547,11 +610,11 @@ TEST(TestLogicalPlanner, MatchRemove) { auto label = dba.Label("label"); QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), REMOVE(PROPERTY_LOOKUP("n", prop)), REMOVE("n", {label}))); - CheckPlan(storage, ExpectScanAll(), ExpectRemoveProperty(), - ExpectRemoveLabels()); + CheckPlan<TypeParam>(storage, ExpectScanAll(), ExpectRemoveProperty(), + ExpectRemoveLabels()); } -TEST(TestLogicalPlanner, MatchMultiPattern) { +TYPED_TEST(TestPlanner, MatchMultiPattern) { // Test MATCH (n) -[r]- (m), (j) -[e]- (i) RETURN n AstTreeStorage storage; QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m")), @@ -559,12 +622,12 @@ TEST(TestLogicalPlanner, MatchMultiPattern) { RETURN("n"))); // We expect the expansions after the first to have a uniqueness filter in a // single MATCH clause. - CheckPlan(storage, ExpectScanAll(), ExpectExpand(), ExpectScanAll(), - ExpectExpand(), ExpectExpandUniquenessFilter<EdgeAccessor>(), - ExpectProduce()); + CheckPlan<TypeParam>( + storage, ExpectScanAll(), ExpectExpand(), ExpectScanAll(), ExpectExpand(), + ExpectExpandUniquenessFilter<EdgeAccessor>(), ExpectProduce()); } -TEST(TestLogicalPlanner, MatchMultiPatternSameStart) { +TYPED_TEST(TestPlanner, MatchMultiPatternSameStart) { // Test MATCH (n), (n) -[e]- (m) RETURN n AstTreeStorage storage; QUERY(SINGLE_QUERY( @@ -572,10 +635,11 @@ TEST(TestLogicalPlanner, MatchMultiPatternSameStart) { RETURN("n"))); // We expect the second pattern to generate only an Expand, since another // ScanAll would be redundant. - CheckPlan(storage, ExpectScanAll(), ExpectExpand(), ExpectProduce()); + CheckPlan<TypeParam>(storage, ExpectScanAll(), ExpectExpand(), + ExpectProduce()); } -TEST(TestLogicalPlanner, MatchMultiPatternSameExpandStart) { +TYPED_TEST(TestPlanner, MatchMultiPatternSameExpandStart) { // Test MATCH (n) -[r]- (m), (m) -[e]- (l) RETURN n AstTreeStorage storage; QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m")), @@ -584,11 +648,12 @@ TEST(TestLogicalPlanner, MatchMultiPatternSameExpandStart) { // We expect the second pattern to generate only an Expand. Another // ScanAll would be redundant, as it would generate the nodes obtained from // expansion. Additionally, a uniqueness filter is expected. - CheckPlan(storage, ExpectScanAll(), ExpectExpand(), ExpectExpand(), - ExpectExpandUniquenessFilter<EdgeAccessor>(), ExpectProduce()); + CheckPlan<TypeParam>(storage, ExpectScanAll(), ExpectExpand(), ExpectExpand(), + ExpectExpandUniquenessFilter<EdgeAccessor>(), + ExpectProduce()); } -TEST(TestLogicalPlanner, MultiMatch) { +TYPED_TEST(TestPlanner, MultiMatch) { // Test MATCH (n) -[r]- (m) MATCH (j) -[e]- (i) -[f]- (h) RETURN n AstTreeStorage storage; QUERY(SINGLE_QUERY( @@ -597,12 +662,13 @@ TEST(TestLogicalPlanner, MultiMatch) { RETURN("n"))); // Multiple MATCH clauses form a Cartesian product, so the uniqueness should // not cross MATCH boundaries. - CheckPlan(storage, ExpectScanAll(), ExpectExpand(), ExpectScanAll(), - ExpectExpand(), ExpectExpand(), - ExpectExpandUniquenessFilter<EdgeAccessor>(), ExpectProduce()); + CheckPlan<TypeParam>(storage, ExpectScanAll(), ExpectExpand(), + ExpectScanAll(), ExpectExpand(), ExpectExpand(), + ExpectExpandUniquenessFilter<EdgeAccessor>(), + ExpectProduce()); } -TEST(TestLogicalPlanner, MultiMatchSameStart) { +TYPED_TEST(TestPlanner, MultiMatchSameStart) { // Test MATCH (n) MATCH (n) -[r]- (m) RETURN n AstTreeStorage storage; QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), @@ -610,19 +676,21 @@ TEST(TestLogicalPlanner, MultiMatchSameStart) { RETURN("n"))); // Similar to MatchMultiPatternSameStart, we expect only Expand from second // MATCH clause. - CheckPlan(storage, ExpectScanAll(), ExpectExpand(), ExpectProduce()); + CheckPlan<TypeParam>(storage, ExpectScanAll(), ExpectExpand(), + ExpectProduce()); } -TEST(TestLogicalPlanner, MatchWithReturn) { +TYPED_TEST(TestPlanner, MatchWithReturn) { // Test MATCH (old) WITH old AS new RETURN new AstTreeStorage storage; QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("old"))), WITH("old", AS("new")), RETURN("new"))); // No accumulation since we only do reads. - CheckPlan(storage, ExpectScanAll(), ExpectProduce(), ExpectProduce()); + CheckPlan<TypeParam>(storage, ExpectScanAll(), ExpectProduce(), + ExpectProduce()); } -TEST(TestLogicalPlanner, MatchWithWhereReturn) { +TYPED_TEST(TestPlanner, MatchWithWhereReturn) { // Test MATCH (old) WITH old AS new WHERE new.prop < 42 RETURN new database::SingleNode db; database::GraphDbAccessor dba(db); @@ -632,11 +700,11 @@ TEST(TestLogicalPlanner, MatchWithWhereReturn) { WHERE(LESS(PROPERTY_LOOKUP("new", prop), LITERAL(42))), RETURN("new"))); // No accumulation since we only do reads. - CheckPlan(storage, ExpectScanAll(), ExpectProduce(), ExpectFilter(), - ExpectProduce()); + CheckPlan<TypeParam>(storage, ExpectScanAll(), ExpectProduce(), + ExpectFilter(), ExpectProduce()); } -TEST(TestLogicalPlanner, CreateMultiExpand) { +TYPED_TEST(TestPlanner, CreateMultiExpand) { // Test CREATE (n) -[r :r]-> (m), (n) - [p :p]-> (l) database::SingleNode db; database::GraphDbAccessor dba(db); @@ -646,11 +714,11 @@ TEST(TestLogicalPlanner, CreateMultiExpand) { QUERY(SINGLE_QUERY( CREATE(PATTERN(NODE("n"), EDGE("r", Direction::OUT, {r}), NODE("m")), PATTERN(NODE("n"), EDGE("p", Direction::OUT, {p}), NODE("l"))))); - CheckPlan(storage, ExpectCreateNode(), ExpectCreateExpand(), - ExpectCreateExpand()); + CheckPlan<TypeParam>(storage, ExpectCreateNode(), ExpectCreateExpand(), + ExpectCreateExpand()); } -TEST(TestLogicalPlanner, MatchWithSumWhereReturn) { +TYPED_TEST(TestPlanner, MatchWithSumWhereReturn) { // Test MATCH (n) WITH SUM(n.prop) + 42 AS sum WHERE sum < 42 // RETURN sum AS result database::SingleNode db; @@ -663,11 +731,11 @@ TEST(TestLogicalPlanner, MatchWithSumWhereReturn) { MATCH(PATTERN(NODE("n"))), WITH(ADD(sum, literal), AS("sum")), WHERE(LESS(IDENT("sum"), LITERAL(42))), RETURN("sum", AS("result")))); auto aggr = ExpectAggregate({sum}, {literal}); - CheckPlan(storage, ExpectScanAll(), aggr, ExpectProduce(), ExpectFilter(), - ExpectProduce()); + CheckPlan<TypeParam>(storage, ExpectScanAll(), aggr, ExpectProduce(), + ExpectFilter(), ExpectProduce()); } -TEST(TestLogicalPlanner, MatchReturnSum) { +TYPED_TEST(TestPlanner, MatchReturnSum) { // Test MATCH (n) RETURN SUM(n.prop1) AS sum, n.prop2 AS group database::SingleNode db; database::GraphDbAccessor dba(db); @@ -679,10 +747,10 @@ TEST(TestLogicalPlanner, MatchReturnSum) { QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), RETURN(sum, AS("sum"), n_prop2, AS("group")))); auto aggr = ExpectAggregate({sum}, {n_prop2}); - CheckPlan(storage, ExpectScanAll(), aggr, ExpectProduce()); + CheckPlan<TypeParam>(storage, ExpectScanAll(), aggr, ExpectProduce()); } -TEST(TestLogicalPlanner, CreateWithSum) { +TYPED_TEST(TestPlanner, CreateWithSum) { // Test CREATE (n) WITH SUM(n.prop) AS sum database::SingleNode db; database::GraphDbAccessor dba(db); @@ -699,15 +767,14 @@ TEST(TestLogicalPlanner, CreateWithSum) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery<RuleBasedPlanner>( - single_query_parts, planning_context); + TypeParam planner(single_query_parts, planning_context); // We expect both the accumulation and aggregation because the part before // WITH updates the database. - CheckPlan(*plan, symbol_table, ExpectCreateNode(), acc, aggr, + CheckPlan(planner.plan(), symbol_table, ExpectCreateNode(), acc, aggr, ExpectProduce()); } -TEST(TestLogicalPlanner, MatchWithCreate) { +TYPED_TEST(TestPlanner, MatchWithCreate) { // Test MATCH (n) WITH n AS a CREATE (a) -[r :r]-> (b) database::SingleNode db; database::GraphDbAccessor dba(db); @@ -717,19 +784,20 @@ TEST(TestLogicalPlanner, MatchWithCreate) { MATCH(PATTERN(NODE("n"))), WITH("n", AS("a")), CREATE( PATTERN(NODE("a"), EDGE("r", Direction::OUT, {r_type}), NODE("b"))))); - CheckPlan(storage, ExpectScanAll(), ExpectProduce(), ExpectCreateExpand()); + CheckPlan<TypeParam>(storage, ExpectScanAll(), ExpectProduce(), + ExpectCreateExpand()); } -TEST(TestLogicalPlanner, MatchReturnSkipLimit) { +TYPED_TEST(TestPlanner, MatchReturnSkipLimit) { // Test MATCH (n) RETURN n SKIP 2 LIMIT 1 AstTreeStorage storage; QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), RETURN("n", SKIP(LITERAL(2)), LIMIT(LITERAL(1))))); - CheckPlan(storage, ExpectScanAll(), ExpectProduce(), ExpectSkip(), - ExpectLimit()); + CheckPlan<TypeParam>(storage, ExpectScanAll(), ExpectProduce(), ExpectSkip(), + ExpectLimit()); } -TEST(TestLogicalPlanner, CreateWithSkipReturnLimit) { +TYPED_TEST(TestPlanner, CreateWithSkipReturnLimit) { // Test CREATE (n) WITH n AS m SKIP 2 RETURN m LIMIT 1 AstTreeStorage storage; auto ident_n = IDENT("n"); @@ -744,18 +812,17 @@ TEST(TestLogicalPlanner, CreateWithSkipReturnLimit) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery<RuleBasedPlanner>( - single_query_parts, planning_context); + TypeParam planner(single_query_parts, planning_context); // Since we have a write query, we need to have Accumulate. This is a bit // different than Neo4j 3.0, which optimizes WITH followed by RETURN as a // single RETURN clause and then moves Skip and Limit before Accumulate. This // causes different behaviour. A newer version of Neo4j does the same thing as // us here (but who knows if they change it again). - CheckPlan(*plan, symbol_table, ExpectCreateNode(), acc, ExpectProduce(), - ExpectSkip(), ExpectProduce(), ExpectLimit()); + CheckPlan(planner.plan(), symbol_table, ExpectCreateNode(), acc, + ExpectProduce(), ExpectSkip(), ExpectProduce(), ExpectLimit()); } -TEST(TestLogicalPlanner, CreateReturnSumSkipLimit) { +TYPED_TEST(TestPlanner, CreateReturnSumSkipLimit) { // Test CREATE (n) RETURN SUM(n.prop) AS s SKIP 2 LIMIT 1 database::SingleNode db; database::GraphDbAccessor dba(db); @@ -773,13 +840,12 @@ TEST(TestLogicalPlanner, CreateReturnSumSkipLimit) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery<RuleBasedPlanner>( - single_query_parts, planning_context); - CheckPlan(*plan, symbol_table, ExpectCreateNode(), acc, aggr, ExpectProduce(), - ExpectSkip(), ExpectLimit()); + TypeParam planner(single_query_parts, planning_context); + CheckPlan(planner.plan(), symbol_table, ExpectCreateNode(), acc, aggr, + ExpectProduce(), ExpectSkip(), ExpectLimit()); } -TEST(TestLogicalPlanner, MatchReturnOrderBy) { +TYPED_TEST(TestPlanner, MatchReturnOrderBy) { // Test MATCH (n) RETURN n ORDER BY n.prop database::SingleNode db; database::GraphDbAccessor dba(db); @@ -787,10 +853,11 @@ TEST(TestLogicalPlanner, MatchReturnOrderBy) { AstTreeStorage storage; auto ret = RETURN("n", ORDER_BY(PROPERTY_LOOKUP("n", prop))); QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), ret)); - CheckPlan(storage, ExpectScanAll(), ExpectProduce(), ExpectOrderBy()); + CheckPlan<TypeParam>(storage, ExpectScanAll(), ExpectProduce(), + ExpectOrderBy()); } -TEST(TestLogicalPlanner, CreateWithOrderByWhere) { +TYPED_TEST(TestPlanner, CreateWithOrderByWhere) { // Test CREATE (n) -[r :r]-> (m) // WITH n AS new ORDER BY new.prop, r.prop WHERE m.prop < 42 database::SingleNode db; @@ -818,13 +885,13 @@ TEST(TestLogicalPlanner, CreateWithOrderByWhere) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery<RuleBasedPlanner>( - single_query_parts, planning_context); - CheckPlan(*plan, symbol_table, ExpectCreateNode(), ExpectCreateExpand(), acc, - ExpectProduce(), ExpectOrderBy(), ExpectFilter()); + TypeParam planner(single_query_parts, planning_context); + CheckPlan(planner.plan(), symbol_table, ExpectCreateNode(), + ExpectCreateExpand(), acc, ExpectProduce(), ExpectOrderBy(), + ExpectFilter()); } -TEST(TestLogicalPlanner, ReturnAddSumCountOrderBy) { +TYPED_TEST(TestPlanner, ReturnAddSumCountOrderBy) { // Test RETURN SUM(1) + COUNT(2) AS result ORDER BY result AstTreeStorage storage; auto sum = SUM(LITERAL(1)); @@ -832,10 +899,10 @@ TEST(TestLogicalPlanner, ReturnAddSumCountOrderBy) { QUERY(SINGLE_QUERY( RETURN(ADD(sum, count), AS("result"), ORDER_BY(IDENT("result"))))); auto aggr = ExpectAggregate({sum, count}, {}); - CheckPlan(storage, aggr, ExpectProduce(), ExpectOrderBy()); + CheckPlan<TypeParam>(storage, aggr, ExpectProduce(), ExpectOrderBy()); } -TEST(TestLogicalPlanner, MatchMerge) { +TYPED_TEST(TestPlanner, MatchMerge) { // Test MATCH (n) MERGE (n) -[r :r]- (m) // ON MATCH SET n.prop = 42 ON CREATE SET m = n // RETURN n AS n @@ -862,9 +929,8 @@ TEST(TestLogicalPlanner, MatchMerge) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery<RuleBasedPlanner>( - single_query_parts, planning_context); - CheckPlan(*plan, symbol_table, ExpectScanAll(), + TypeParam planner(single_query_parts, planning_context); + CheckPlan(planner.plan(), symbol_table, ExpectScanAll(), ExpectMerge(on_match, on_create), acc, ExpectProduce()); for (auto &op : on_match) delete op; on_match.clear(); @@ -872,7 +938,7 @@ TEST(TestLogicalPlanner, MatchMerge) { on_create.clear(); } -TEST(TestLogicalPlanner, MatchOptionalMatchWhereReturn) { +TYPED_TEST(TestPlanner, MatchOptionalMatchWhereReturn) { // Test MATCH (n) OPTIONAL MATCH (n) -[r]- (m) WHERE m.prop < 42 RETURN r database::SingleNode db; database::GraphDbAccessor dba(db); @@ -884,29 +950,30 @@ TEST(TestLogicalPlanner, MatchOptionalMatchWhereReturn) { RETURN("r"))); std::list<BaseOpChecker *> optional{new ExpectScanAll(), new ExpectExpand(), new ExpectFilter()}; - CheckPlan(storage, ExpectScanAll(), ExpectOptional(optional), - ExpectProduce()); + CheckPlan<TypeParam>(storage, ExpectScanAll(), ExpectOptional(optional), + ExpectProduce()); } -TEST(TestLogicalPlanner, MatchUnwindReturn) { +TYPED_TEST(TestPlanner, MatchUnwindReturn) { // Test MATCH (n) UNWIND [1,2,3] AS x RETURN n, x AstTreeStorage storage; QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"))), UNWIND(LIST(LITERAL(1), LITERAL(2), LITERAL(3)), AS("x")), RETURN("n", "x"))); - CheckPlan(storage, ExpectScanAll(), ExpectUnwind(), ExpectProduce()); + CheckPlan<TypeParam>(storage, ExpectScanAll(), ExpectUnwind(), + ExpectProduce()); } -TEST(TestLogicalPlanner, ReturnDistinctOrderBySkipLimit) { +TYPED_TEST(TestPlanner, ReturnDistinctOrderBySkipLimit) { // Test RETURN DISTINCT 1 ORDER BY 1 SKIP 1 LIMIT 1 AstTreeStorage storage; QUERY(SINGLE_QUERY(RETURN_DISTINCT(LITERAL(1), AS("1"), ORDER_BY(LITERAL(1)), SKIP(LITERAL(1)), LIMIT(LITERAL(1))))); - CheckPlan(storage, ExpectProduce(), ExpectDistinct(), ExpectOrderBy(), - ExpectSkip(), ExpectLimit()); + CheckPlan<TypeParam>(storage, ExpectProduce(), ExpectDistinct(), + ExpectOrderBy(), ExpectSkip(), ExpectLimit()); } -TEST(TestLogicalPlanner, CreateWithDistinctSumWhereReturn) { +TYPED_TEST(TestPlanner, CreateWithDistinctSumWhereReturn) { // Test CREATE (n) WITH DISTINCT SUM(n.prop) AS s WHERE s < 42 RETURN s database::SingleNode db; database::GraphDbAccessor dba(db); @@ -924,13 +991,12 @@ TEST(TestLogicalPlanner, CreateWithDistinctSumWhereReturn) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery<RuleBasedPlanner>( - single_query_parts, planning_context); - CheckPlan(*plan, symbol_table, ExpectCreateNode(), acc, aggr, ExpectProduce(), - ExpectDistinct(), ExpectFilter(), ExpectProduce()); + TypeParam planner(single_query_parts, planning_context); + CheckPlan(planner.plan(), symbol_table, ExpectCreateNode(), acc, aggr, + ExpectProduce(), ExpectDistinct(), ExpectFilter(), ExpectProduce()); } -TEST(TestLogicalPlanner, MatchCrossReferenceVariable) { +TYPED_TEST(TestPlanner, MatchCrossReferenceVariable) { // Test MATCH (n {prop: m.prop}), (m {prop: n.prop}) RETURN n database::SingleNode db; database::GraphDbAccessor dba(db); @@ -945,11 +1011,11 @@ TEST(TestLogicalPlanner, MatchCrossReferenceVariable) { QUERY(SINGLE_QUERY(MATCH(PATTERN(node_n), PATTERN(node_m)), RETURN("n"))); // We expect both ScanAll to come before filters (2 are joined into one), // because they need to populate the symbol values. - CheckPlan(storage, ExpectScanAll(), ExpectScanAll(), ExpectFilter(), - ExpectProduce()); + CheckPlan<TypeParam>(storage, ExpectScanAll(), ExpectScanAll(), + ExpectFilter(), ExpectProduce()); } -TEST(TestLogicalPlanner, MatchWhereBeforeExpand) { +TYPED_TEST(TestPlanner, MatchWhereBeforeExpand) { // Test MATCH (n) -[r]- (m) WHERE n.prop < 42 RETURN n database::SingleNode db; database::GraphDbAccessor dba(db); @@ -959,11 +1025,11 @@ TEST(TestLogicalPlanner, MatchWhereBeforeExpand) { WHERE(LESS(PROPERTY_LOOKUP("n", prop), LITERAL(42))), RETURN("n"))); // We expect Fitler to come immediately after ScanAll, since it only uses `n`. - CheckPlan(storage, ExpectScanAll(), ExpectFilter(), ExpectExpand(), - ExpectProduce()); + CheckPlan<TypeParam>(storage, ExpectScanAll(), ExpectFilter(), ExpectExpand(), + ExpectProduce()); } -TEST(TestLogicalPlanner, MultiMatchWhere) { +TYPED_TEST(TestPlanner, MultiMatchWhere) { // Test MATCH (n) -[r]- (m) MATCH (l) WHERE n.prop < 42 RETURN n database::SingleNode db; database::GraphDbAccessor dba(db); @@ -975,11 +1041,11 @@ TEST(TestLogicalPlanner, MultiMatchWhere) { RETURN("n"))); // Even though WHERE is in the second MATCH clause, we expect Filter to come // before second ScanAll, since it only uses the value from first ScanAll. - CheckPlan(storage, ExpectScanAll(), ExpectFilter(), ExpectExpand(), - ExpectScanAll(), ExpectProduce()); + CheckPlan<TypeParam>(storage, ExpectScanAll(), ExpectFilter(), ExpectExpand(), + ExpectScanAll(), ExpectProduce()); } -TEST(TestLogicalPlanner, MatchOptionalMatchWhere) { +TYPED_TEST(TestPlanner, MatchOptionalMatchWhere) { // Test MATCH (n) -[r]- (m) OPTIONAL MATCH (l) WHERE n.prop < 42 RETURN n database::SingleNode db; database::GraphDbAccessor dba(db); @@ -993,11 +1059,11 @@ TEST(TestLogicalPlanner, MatchOptionalMatchWhere) { // first ScanAll, it must remain part of the Optional. It should come before // optional ScanAll. std::list<BaseOpChecker *> optional{new ExpectFilter(), new ExpectScanAll()}; - CheckPlan(storage, ExpectScanAll(), ExpectExpand(), ExpectOptional(optional), - ExpectProduce()); + CheckPlan<TypeParam>(storage, ExpectScanAll(), ExpectExpand(), + ExpectOptional(optional), ExpectProduce()); } -TEST(TestLogicalPlanner, MatchReturnAsterisk) { +TYPED_TEST(TestPlanner, MatchReturnAsterisk) { // Test MATCH (n) -[e]- (m) RETURN *, m.prop database::SingleNode db; database::GraphDbAccessor dba(db); @@ -1012,19 +1078,18 @@ TEST(TestLogicalPlanner, MatchReturnAsterisk) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery<RuleBasedPlanner>( - single_query_parts, planning_context); - CheckPlan(*plan, symbol_table, ExpectScanAll(), ExpectExpand(), + TypeParam planner(single_query_parts, planning_context); + CheckPlan(planner.plan(), symbol_table, ExpectScanAll(), ExpectExpand(), ExpectProduce()); std::vector<std::string> output_names; - for (const auto &output_symbol : plan->OutputSymbols(symbol_table)) { + for (const auto &output_symbol : planner.plan().OutputSymbols(symbol_table)) { output_names.emplace_back(output_symbol.name()); } std::vector<std::string> expected_names{"e", "m", "n", "m.prop"}; EXPECT_EQ(output_names, expected_names); } -TEST(TestLogicalPlanner, MatchReturnAsteriskSum) { +TYPED_TEST(TestPlanner, MatchReturnAsteriskSum) { // Test MATCH (n) RETURN *, SUM(n.prop) AS s database::SingleNode db; database::GraphDbAccessor dba(db); @@ -1039,9 +1104,8 @@ TEST(TestLogicalPlanner, MatchReturnAsteriskSum) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery<RuleBasedPlanner>( - single_query_parts, planning_context); - auto *produce = dynamic_cast<Produce *>(plan.get()); + TypeParam planner(single_query_parts, planning_context); + auto *produce = dynamic_cast<Produce *>(&planner.plan()); ASSERT_TRUE(produce); const auto &named_expressions = produce->named_expressions(); ASSERT_EQ(named_expressions.size(), 2); @@ -1049,16 +1113,17 @@ TEST(TestLogicalPlanner, MatchReturnAsteriskSum) { dynamic_cast<query::Identifier *>(named_expressions[0]->expression_); ASSERT_TRUE(expanded_ident); auto aggr = ExpectAggregate({sum}, {expanded_ident}); - CheckPlan(*plan, symbol_table, ExpectScanAll(), aggr, ExpectProduce()); + CheckPlan(planner.plan(), symbol_table, ExpectScanAll(), aggr, + ExpectProduce()); std::vector<std::string> output_names; - for (const auto &output_symbol : plan->OutputSymbols(symbol_table)) { + for (const auto &output_symbol : planner.plan().OutputSymbols(symbol_table)) { output_names.emplace_back(output_symbol.name()); } std::vector<std::string> expected_names{"n", "s"}; EXPECT_EQ(output_names, expected_names); } -TEST(TestLogicalPlanner, UnwindMergeNodeProperty) { +TYPED_TEST(TestPlanner, UnwindMergeNodeProperty) { // Test UNWIND [1] AS i MERGE (n {prop: i}) database::SingleNode db; database::GraphDbAccessor dba(db); @@ -1069,22 +1134,23 @@ TEST(TestLogicalPlanner, UnwindMergeNodeProperty) { SINGLE_QUERY(UNWIND(LIST(LITERAL(1)), AS("i")), MERGE(PATTERN(node_n)))); std::list<BaseOpChecker *> on_match{new ExpectScanAll(), new ExpectFilter()}; std::list<BaseOpChecker *> on_create{new ExpectCreateNode()}; - CheckPlan(storage, ExpectUnwind(), ExpectMerge(on_match, on_create)); + CheckPlan<TypeParam>(storage, ExpectUnwind(), + ExpectMerge(on_match, on_create)); for (auto &op : on_match) delete op; for (auto &op : on_create) delete op; } -TEST(TestLogicalPlanner, MultipleOptionalMatchReturn) { +TYPED_TEST(TestPlanner, MultipleOptionalMatchReturn) { // Test OPTIONAL MATCH (n) OPTIONAL MATCH (m) RETURN n AstTreeStorage storage; QUERY(SINGLE_QUERY(OPTIONAL_MATCH(PATTERN(NODE("n"))), OPTIONAL_MATCH(PATTERN(NODE("m"))), RETURN("n"))); std::list<BaseOpChecker *> optional{new ExpectScanAll()}; - CheckPlan(storage, ExpectOptional(optional), ExpectOptional(optional), - ExpectProduce()); + CheckPlan<TypeParam>(storage, ExpectOptional(optional), + ExpectOptional(optional), ExpectProduce()); } -TEST(TestLogicalPlanner, FunctionAggregationReturn) { +TYPED_TEST(TestPlanner, FunctionAggregationReturn) { // Test RETURN sqrt(SUM(2)) AS result, 42 AS group_by AstTreeStorage storage; auto sum = SUM(LITERAL(2)); @@ -1092,17 +1158,17 @@ TEST(TestLogicalPlanner, FunctionAggregationReturn) { QUERY(SINGLE_QUERY( RETURN(FN("sqrt", sum), AS("result"), group_by_literal, AS("group_by")))); auto aggr = ExpectAggregate({sum}, {group_by_literal}); - CheckPlan(storage, aggr, ExpectProduce()); + CheckPlan<TypeParam>(storage, aggr, ExpectProduce()); } -TEST(TestLogicalPlanner, FunctionWithoutArguments) { +TYPED_TEST(TestPlanner, FunctionWithoutArguments) { // Test RETURN pi() AS pi AstTreeStorage storage; QUERY(SINGLE_QUERY(RETURN(FN("pi"), AS("pi")))); - CheckPlan(storage, ExpectProduce()); + CheckPlan<TypeParam>(storage, ExpectProduce()); } -TEST(TestLogicalPlanner, ListLiteralAggregationReturn) { +TYPED_TEST(TestPlanner, ListLiteralAggregationReturn) { // Test RETURN [SUM(2)] AS result, 42 AS group_by AstTreeStorage storage; auto sum = SUM(LITERAL(2)); @@ -1110,10 +1176,10 @@ TEST(TestLogicalPlanner, ListLiteralAggregationReturn) { QUERY(SINGLE_QUERY( RETURN(LIST(sum), AS("result"), group_by_literal, AS("group_by")))); auto aggr = ExpectAggregate({sum}, {group_by_literal}); - CheckPlan(storage, aggr, ExpectProduce()); + CheckPlan<TypeParam>(storage, aggr, ExpectProduce()); } -TEST(TestLogicalPlanner, MapLiteralAggregationReturn) { +TYPED_TEST(TestPlanner, MapLiteralAggregationReturn) { // Test RETURN {sum: SUM(2)} AS result, 42 AS group_by AstTreeStorage storage; database::SingleNode db; @@ -1123,10 +1189,10 @@ TEST(TestLogicalPlanner, MapLiteralAggregationReturn) { QUERY(SINGLE_QUERY(RETURN(MAP({PROPERTY_PAIR("sum"), sum}), AS("result"), group_by_literal, AS("group_by")))); auto aggr = ExpectAggregate({sum}, {group_by_literal}); - CheckPlan(storage, aggr, ExpectProduce()); + CheckPlan<TypeParam>(storage, aggr, ExpectProduce()); } -TEST(TestLogicalPlanner, EmptyListIndexAggregation) { +TYPED_TEST(TestPlanner, EmptyListIndexAggregation) { // Test RETURN [][SUM(2)] AS result, 42 AS group_by AstTreeStorage storage; auto sum = SUM(LITERAL(2)); @@ -1139,10 +1205,10 @@ TEST(TestLogicalPlanner, EmptyListIndexAggregation) { // sub-expression of a binary operator which contains an aggregation. This is // similar to grouping by '1' in `RETURN 1 + SUM(2)`. auto aggr = ExpectAggregate({sum}, {empty_list, group_by_literal}); - CheckPlan(storage, aggr, ExpectProduce()); + CheckPlan<TypeParam>(storage, aggr, ExpectProduce()); } -TEST(TestLogicalPlanner, ListSliceAggregationReturn) { +TYPED_TEST(TestPlanner, ListSliceAggregationReturn) { // Test RETURN [1, 2][0..SUM(2)] AS result, 42 AS group_by AstTreeStorage storage; auto sum = SUM(LITERAL(2)); @@ -1153,20 +1219,20 @@ TEST(TestLogicalPlanner, ListSliceAggregationReturn) { // Similarly to EmptyListIndexAggregation test, we expect grouping by list and // '42', because slicing is an operator. auto aggr = ExpectAggregate({sum}, {list, group_by_literal}); - CheckPlan(storage, aggr, ExpectProduce()); + CheckPlan<TypeParam>(storage, aggr, ExpectProduce()); } -TEST(TestLogicalPlanner, ListWithAggregationAndGroupBy) { +TYPED_TEST(TestPlanner, ListWithAggregationAndGroupBy) { // Test RETURN [sum(2), 42] AstTreeStorage storage; auto sum = SUM(LITERAL(2)); auto group_by_literal = LITERAL(42); QUERY(SINGLE_QUERY(RETURN(LIST(sum, group_by_literal), AS("result")))); auto aggr = ExpectAggregate({sum}, {group_by_literal}); - CheckPlan(storage, aggr, ExpectProduce()); + CheckPlan<TypeParam>(storage, aggr, ExpectProduce()); } -TEST(TestLogicalPlanner, AggregatonWithListWithAggregationAndGroupBy) { +TYPED_TEST(TestPlanner, AggregatonWithListWithAggregationAndGroupBy) { // Test RETURN sum(2), [sum(3), 42] AstTreeStorage storage; auto sum2 = SUM(LITERAL(2)); @@ -1175,10 +1241,10 @@ TEST(TestLogicalPlanner, AggregatonWithListWithAggregationAndGroupBy) { QUERY(SINGLE_QUERY( RETURN(sum2, AS("sum2"), LIST(sum3, group_by_literal), AS("list")))); auto aggr = ExpectAggregate({sum2, sum3}, {group_by_literal}); - CheckPlan(storage, aggr, ExpectProduce()); + CheckPlan<TypeParam>(storage, aggr, ExpectProduce()); } -TEST(TestLogicalPlanner, MapWithAggregationAndGroupBy) { +TYPED_TEST(TestPlanner, MapWithAggregationAndGroupBy) { // Test RETURN {lit: 42, sum: sum(2)} database::SingleNode db; AstTreeStorage storage; @@ -1188,10 +1254,10 @@ TEST(TestLogicalPlanner, MapWithAggregationAndGroupBy) { {PROPERTY_PAIR("lit"), group_by_literal}), AS("result")))); auto aggr = ExpectAggregate({sum}, {group_by_literal}); - CheckPlan(storage, aggr, ExpectProduce()); + CheckPlan<TypeParam>(storage, aggr, ExpectProduce()); } -TEST(TestLogicalPlanner, CreateIndex) { +TYPED_TEST(TestPlanner, CreateIndex) { // Test CREATE INDEX ON :Label(property) database::SingleNode db; database::GraphDbAccessor dba(db); @@ -1199,10 +1265,10 @@ TEST(TestLogicalPlanner, CreateIndex) { auto property = dba.Property("property"); AstTreeStorage storage; QUERY(SINGLE_QUERY(CREATE_INDEX_ON(label, property))); - CheckPlan(storage, ExpectCreateIndex(label, property)); + CheckPlan<TypeParam>(storage, ExpectCreateIndex(label, property)); } -TEST(TestLogicalPlanner, AtomIndexedLabelProperty) { +TYPED_TEST(TestPlanner, AtomIndexedLabelProperty) { // Test MATCH (n :label {property: 42, not_indexed: 0}) RETURN n AstTreeStorage storage; database::SingleNode db; @@ -1227,16 +1293,14 @@ TEST(TestLogicalPlanner, AtomIndexedLabelProperty) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery<RuleBasedPlanner>( - single_query_parts, planning_context); - - CheckPlan(*plan, symbol_table, + TypeParam planner(single_query_parts, planning_context); + CheckPlan(planner.plan(), symbol_table, ExpectScanAllByLabelPropertyValue(label, property, lit_42), ExpectFilter(), ExpectProduce()); } } -TEST(TestLogicalPlanner, AtomPropertyWhereLabelIndexing) { +TYPED_TEST(TestPlanner, AtomPropertyWhereLabelIndexing) { // Test MATCH (n {property: 42}) WHERE n.not_indexed AND n:label RETURN n AstTreeStorage storage; database::SingleNode db; @@ -1261,16 +1325,14 @@ TEST(TestLogicalPlanner, AtomPropertyWhereLabelIndexing) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery<RuleBasedPlanner>( - single_query_parts, planning_context); - - CheckPlan(*plan, symbol_table, + TypeParam planner(single_query_parts, planning_context); + CheckPlan(planner.plan(), symbol_table, ExpectScanAllByLabelPropertyValue(label, property, lit_42), ExpectFilter(), ExpectProduce()); } } -TEST(TestLogicalPlanner, WhereIndexedLabelProperty) { +TYPED_TEST(TestPlanner, WhereIndexedLabelProperty) { // Test MATCH (n :label) WHERE n.property = 42 RETURN n AstTreeStorage storage; database::SingleNode db; @@ -1289,15 +1351,14 @@ TEST(TestLogicalPlanner, WhereIndexedLabelProperty) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery<RuleBasedPlanner>( - single_query_parts, planning_context); - CheckPlan(*plan, symbol_table, + TypeParam planner(single_query_parts, planning_context); + CheckPlan(planner.plan(), symbol_table, ExpectScanAllByLabelPropertyValue(label, property, lit_42), ExpectProduce()); } } -TEST(TestLogicalPlanner, BestPropertyIndexed) { +TYPED_TEST(TestPlanner, BestPropertyIndexed) { // Test MATCH (n :label) WHERE n.property = 1 AND n.better = 42 RETURN n AstTreeStorage storage; database::SingleNode db; @@ -1329,15 +1390,14 @@ TEST(TestLogicalPlanner, BestPropertyIndexed) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery<RuleBasedPlanner>( - single_query_parts, planning_context); - CheckPlan(*plan, symbol_table, + TypeParam planner(single_query_parts, planning_context); + CheckPlan(planner.plan(), symbol_table, ExpectScanAllByLabelPropertyValue(label, better, lit_42), ExpectFilter(), ExpectProduce()); } } -TEST(TestLogicalPlanner, MultiPropertyIndexScan) { +TYPED_TEST(TestPlanner, MultiPropertyIndexScan) { // Test MATCH (n :label1), (m :label2) WHERE n.prop1 = 1 AND m.prop2 = 2 // RETURN n, m database::SingleNode db; @@ -1361,15 +1421,14 @@ TEST(TestLogicalPlanner, MultiPropertyIndexScan) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery<RuleBasedPlanner>( - single_query_parts, planning_context); - CheckPlan(*plan, symbol_table, + TypeParam planner(single_query_parts, planning_context); + CheckPlan(planner.plan(), symbol_table, ExpectScanAllByLabelPropertyValue(label1, prop1, lit_1), ExpectScanAllByLabelPropertyValue(label2, prop2, lit_2), ExpectProduce()); } -TEST(TestLogicalPlanner, WhereIndexedLabelPropertyRange) { +TYPED_TEST(TestPlanner, WhereIndexedLabelPropertyRange) { // Test MATCH (n :label) WHERE n.property REL_OP 42 RETURN n // REL_OP is one of: `<`, `<=`, `>`, `>=` database::SingleNode db; @@ -1380,8 +1439,9 @@ TEST(TestLogicalPlanner, WhereIndexedLabelPropertyRange) { AstTreeStorage storage; auto lit_42 = LITERAL(42); auto n_prop = PROPERTY_LOOKUP("n", property); - auto check_planned_range = [&label, &property, &dba]( - const auto &rel_expr, auto lower_bound, auto upper_bound) { + auto check_planned_range = [&label, &property, &dba](const auto &rel_expr, + auto lower_bound, + auto upper_bound) { // Shadow the first storage, so that the query is created in this one. AstTreeStorage storage; QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n", label))), WHERE(rel_expr), @@ -1391,9 +1451,8 @@ TEST(TestLogicalPlanner, WhereIndexedLabelPropertyRange) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery<RuleBasedPlanner>( - single_query_parts, planning_context); - CheckPlan(*plan, symbol_table, + TypeParam planner(single_query_parts, planning_context); + CheckPlan(planner.plan(), symbol_table, ExpectScanAllByLabelPropertyRange(label, property, lower_bound, upper_bound), ExpectProduce()); @@ -1424,7 +1483,7 @@ TEST(TestLogicalPlanner, WhereIndexedLabelPropertyRange) { } } -TEST(TestLogicalPlanner, UnableToUsePropertyIndex) { +TYPED_TEST(TestPlanner, UnableToUsePropertyIndex) { // Test MATCH (n: label) WHERE n.property = n.property RETURN n database::SingleNode db; database::GraphDbAccessor dba(db); @@ -1443,16 +1502,15 @@ TEST(TestLogicalPlanner, UnableToUsePropertyIndex) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery<RuleBasedPlanner>( - single_query_parts, planning_context); + TypeParam planner(single_query_parts, planning_context); // We can only get ScanAllByLabelIndex, because we are comparing properties // with those on the same node. - CheckPlan(*plan, symbol_table, ExpectScanAllByLabel(), ExpectFilter(), - ExpectProduce()); + CheckPlan(planner.plan(), symbol_table, ExpectScanAllByLabel(), + ExpectFilter(), ExpectProduce()); } } -TEST(TestLogicalPlanner, SecondPropertyIndex) { +TYPED_TEST(TestPlanner, SecondPropertyIndex) { // Test MATCH (n :label), (m :label) WHERE m.property = n.property RETURN n database::SingleNode db; database::GraphDbAccessor dba(db); @@ -1472,44 +1530,45 @@ TEST(TestLogicalPlanner, SecondPropertyIndex) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery<RuleBasedPlanner>( - single_query_parts, planning_context); + TypeParam planner(single_query_parts, planning_context); CheckPlan( - *plan, symbol_table, ExpectScanAllByLabel(), + planner.plan(), symbol_table, ExpectScanAllByLabel(), // Note: We are scanning for m, therefore property should equal n_prop. ExpectScanAllByLabelPropertyValue(label, property, n_prop), ExpectProduce()); } } -TEST(TestLogicalPlanner, ReturnSumGroupByAll) { +TYPED_TEST(TestPlanner, ReturnSumGroupByAll) { // Test RETURN sum([1,2,3]), all(x in [1] where x = 1) AstTreeStorage storage; auto sum = SUM(LIST(LITERAL(1), LITERAL(2), LITERAL(3))); auto *all = ALL("x", LIST(LITERAL(1)), WHERE(EQ(IDENT("x"), LITERAL(1)))); QUERY(SINGLE_QUERY(RETURN(sum, AS("sum"), all, AS("all")))); auto aggr = ExpectAggregate({sum}, {all}); - CheckPlan(storage, aggr, ExpectProduce()); + CheckPlan<TypeParam>(storage, aggr, ExpectProduce()); } -TEST(TestLogicalPlanner, MatchExpandVariable) { +TYPED_TEST(TestPlanner, MatchExpandVariable) { // Test MATCH (n) -[r *..3]-> (m) RETURN r AstTreeStorage storage; auto edge = EDGE_VARIABLE("r"); edge->upper_bound_ = LITERAL(3); QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"), edge, NODE("m"))), RETURN("r"))); - CheckPlan(storage, ExpectScanAll(), ExpectExpandVariable(), ExpectProduce()); + CheckPlan<TypeParam>(storage, ExpectScanAll(), ExpectExpandVariable(), + ExpectProduce()); } -TEST(TestLogicalPlanner, MatchExpandVariableNoBounds) { +TYPED_TEST(TestPlanner, MatchExpandVariableNoBounds) { // Test MATCH (n) -[r *]-> (m) RETURN r AstTreeStorage storage; auto edge = EDGE_VARIABLE("r"); QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"), edge, NODE("m"))), RETURN("r"))); - CheckPlan(storage, ExpectScanAll(), ExpectExpandVariable(), ExpectProduce()); + CheckPlan<TypeParam>(storage, ExpectScanAll(), ExpectExpandVariable(), + ExpectProduce()); } -TEST(TestLogicalPlanner, MatchExpandVariableInlinedFilter) { +TYPED_TEST(TestPlanner, MatchExpandVariableInlinedFilter) { // Test MATCH (n) -[r :type * {prop: 42}]-> (m) RETURN r database::SingleNode db; database::GraphDbAccessor dba(db); @@ -1519,12 +1578,13 @@ TEST(TestLogicalPlanner, MatchExpandVariableInlinedFilter) { auto edge = EDGE_VARIABLE("r", Direction::BOTH, {type}); edge->properties_[prop] = LITERAL(42); QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"), edge, NODE("m"))), RETURN("r"))); - CheckPlan(storage, ExpectScanAll(), - ExpectExpandVariable(), // Filter is both inlined and post-expand - ExpectFilter(), ExpectProduce()); + CheckPlan<TypeParam>( + storage, ExpectScanAll(), + ExpectExpandVariable(), // Filter is both inlined and post-expand + ExpectFilter(), ExpectProduce()); } -TEST(TestLogicalPlanner, MatchExpandVariableNotInlinedFilter) { +TYPED_TEST(TestPlanner, MatchExpandVariableNotInlinedFilter) { // Test MATCH (n) -[r :type * {prop: m.prop}]-> (m) RETURN r database::SingleNode db; database::GraphDbAccessor dba(db); @@ -1534,11 +1594,11 @@ TEST(TestLogicalPlanner, MatchExpandVariableNotInlinedFilter) { auto edge = EDGE_VARIABLE("r", Direction::BOTH, {type}); edge->properties_[prop] = EQ(PROPERTY_LOOKUP("m", prop), LITERAL(42)); QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"), edge, NODE("m"))), RETURN("r"))); - CheckPlan(storage, ExpectScanAll(), ExpectExpandVariable(), ExpectFilter(), - ExpectProduce()); + CheckPlan<TypeParam>(storage, ExpectScanAll(), ExpectExpandVariable(), + ExpectFilter(), ExpectProduce()); } -TEST(TestLogicalPlanner, UnwindMatchVariable) { +TYPED_TEST(TestPlanner, UnwindMatchVariable) { // Test UNWIND [1,2,3] AS depth MATCH (n) -[r*d]-> (m) RETURN r AstTreeStorage storage; auto edge = EDGE_VARIABLE("r", Direction::OUT); @@ -1546,11 +1606,11 @@ TEST(TestLogicalPlanner, UnwindMatchVariable) { edge->upper_bound_ = IDENT("d"); QUERY(SINGLE_QUERY(UNWIND(LIST(LITERAL(1), LITERAL(2), LITERAL(3)), AS("d")), MATCH(PATTERN(NODE("n"), edge, NODE("m"))), RETURN("r"))); - CheckPlan(storage, ExpectUnwind(), ExpectScanAll(), ExpectExpandVariable(), - ExpectProduce()); + CheckPlan<TypeParam>(storage, ExpectUnwind(), ExpectScanAll(), + ExpectExpandVariable(), ExpectProduce()); } -TEST(TestLogicalPlanner, MatchBreadthFirst) { +TYPED_TEST(TestPlanner, MatchBreadthFirst) { // Test MATCH (n) -[r:type *..10 (r, n|n)]-> (m) RETURN r database::SingleNode db; database::GraphDbAccessor dba(db); @@ -1564,11 +1624,11 @@ TEST(TestLogicalPlanner, MatchBreadthFirst) { bfs->filter_expression_ = IDENT("n"); bfs->upper_bound_ = LITERAL(10); QUERY(SINGLE_QUERY(MATCH(PATTERN(NODE("n"), bfs, NODE("m"))), RETURN("r"))); - CheckPlan(storage, ExpectScanAll(), ExpectExpandBreadthFirst(), - ExpectProduce()); + CheckPlan<TypeParam>(storage, ExpectScanAll(), ExpectExpandBreadthFirst(), + ExpectProduce()); } -TEST(TestLogicalPlanner, MatchDoubleScanToExpandExisting) { +TYPED_TEST(TestPlanner, MatchDoubleScanToExpandExisting) { // Test MATCH (n) -[r]- (m :label) RETURN r database::SingleNode db; database::GraphDbAccessor dba(db); @@ -1581,15 +1641,14 @@ TEST(TestLogicalPlanner, MatchDoubleScanToExpandExisting) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery<RuleBasedPlanner>( - single_query_parts, planning_context); + TypeParam planner(single_query_parts, planning_context); // We expect 2x ScanAll and then Expand, since we are guessing that is // faster (due to low label index vertex count). - CheckPlan(*plan, symbol_table, ExpectScanAll(), ExpectScanAllByLabel(), - ExpectExpand(), ExpectProduce()); + CheckPlan(planner.plan(), symbol_table, ExpectScanAll(), + ExpectScanAllByLabel(), ExpectExpand(), ExpectProduce()); } -TEST(TestLogicalPlanner, MatchScanToExpand) { +TYPED_TEST(TestPlanner, MatchScanToExpand) { // Test MATCH (n) -[r]- (m :label {property: 1}) RETURN r database::SingleNode db; auto label = database::GraphDbAccessor(db).Label("label"); @@ -1619,16 +1678,15 @@ TEST(TestLogicalPlanner, MatchScanToExpand) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery<RuleBasedPlanner>( - single_query_parts, planning_context); + TypeParam planner(single_query_parts, planning_context); // We expect 1x ScanAllByLabel and then Expand, since we are guessing that // is faster (due to high label index vertex count). - CheckPlan(*plan, symbol_table, ExpectScanAll(), ExpectExpand(), + CheckPlan(planner.plan(), symbol_table, ExpectScanAll(), ExpectExpand(), ExpectFilter(), ExpectProduce()); } } -TEST(TestLogicalPlanner, MatchWhereAndSplit) { +TYPED_TEST(TestPlanner, MatchWhereAndSplit) { // Test MATCH (n) -[r]- (m) WHERE n.prop AND r.prop RETURN m database::SingleNode db; database::GraphDbAccessor dba(db); @@ -1639,11 +1697,11 @@ TEST(TestLogicalPlanner, MatchWhereAndSplit) { WHERE(AND(PROPERTY_LOOKUP("n", prop), PROPERTY_LOOKUP("r", prop))), RETURN("m"))); // We expect `n.prop` filter right after scanning `n`. - CheckPlan(storage, ExpectScanAll(), ExpectFilter(), ExpectExpand(), - ExpectFilter(), ExpectProduce()); + CheckPlan<TypeParam>(storage, ExpectScanAll(), ExpectFilter(), ExpectExpand(), + ExpectFilter(), ExpectProduce()); } -TEST(TestLogicalPlanner, ReturnAsteriskOmitsLambdaSymbols) { +TYPED_TEST(TestPlanner, ReturnAsteriskOmitsLambdaSymbols) { // Test MATCH (n) -[r* (ie, in | true)]- (m) RETURN * database::SingleNode db; database::GraphDbAccessor dba(db); @@ -1660,9 +1718,8 @@ TEST(TestLogicalPlanner, ReturnAsteriskOmitsLambdaSymbols) { auto query_parts = CollectQueryParts(symbol_table, storage); ASSERT_TRUE(query_parts.query_parts.size() > 0); auto single_query_parts = query_parts.query_parts.at(0).single_query_parts; - auto plan = MakeLogicalPlanForSingleQuery<RuleBasedPlanner>( - single_query_parts, planning_context); - auto *produce = dynamic_cast<Produce *>(plan.get()); + TypeParam planner(single_query_parts, planning_context); + auto *produce = dynamic_cast<Produce *>(&planner.plan()); ASSERT_TRUE(produce); std::vector<std::string> outputs; for (const auto &output_symbol : produce->OutputSymbols(symbol_table)) {