From a83bea0b74a4b6ac9d276939f4a679dc51fa85de Mon Sep 17 00:00:00 2001 From: Teon Banek Date: Wed, 13 Sep 2017 10:27:12 +0200 Subject: [PATCH] Add ParameterLookup to AST Reviewers: florijan, mislav.bradac Reviewed By: mislav.bradac Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D782 --- src/query/context.hpp | 22 +- src/query/frontend/ast/ast.hpp | 27 +- src/query/frontend/ast/ast_visitor.hpp | 13 +- .../frontend/semantic/symbol_generator.hpp | 1 + src/query/interpret/eval.hpp | 9 +- src/query/interpreter.hpp | 26 +- src/query/plan/operator.cpp | 247 +++++++++--------- src/query/plan/operator.hpp | 60 +++-- src/query/plan/rule_based_planner.cpp | 2 + tests/manual/query_planner.cpp | 3 +- tests/unit/cypher_main_visitor.cpp | 2 +- tests/unit/query_expression_evaluator.cpp | 17 +- tests/unit/query_plan_common.hpp | 9 +- .../query_plan_create_set_remove_delete.cpp | 4 +- tests/unit/query_plan_match_filter_return.cpp | 8 +- 15 files changed, 254 insertions(+), 196 deletions(-) diff --git a/src/query/context.hpp b/src/query/context.hpp index cb275d644..def5e57ae 100644 --- a/src/query/context.hpp +++ b/src/query/context.hpp @@ -2,21 +2,17 @@ #include "antlr4-runtime.h" #include "database/graph_db_accessor.hpp" +#include "query/frontend/semantic/symbol_table.hpp" +#include "query/parameters.hpp" namespace query { -/** - * Future-proofing for the time when we'll actually have - * something to configure in query execution. - */ -struct Config { +class Context { + public: + Context(GraphDbAccessor &db_accessor) : db_accessor_(db_accessor) {} + GraphDbAccessor &db_accessor_; + SymbolTable symbol_table_; + Parameters parameters_; }; -class Context { -public: - Context(Config config, GraphDbAccessor &db_accessor) - : config_(config), db_accessor_(db_accessor) {} - Config config_; - GraphDbAccessor &db_accessor_; -}; -} +} // namespace query diff --git a/src/query/frontend/ast/ast.hpp b/src/query/frontend/ast/ast.hpp index ddeb9a05e..c10c0f7e8 100644 --- a/src/query/frontend/ast/ast.hpp +++ b/src/query/frontend/ast/ast.hpp @@ -16,14 +16,14 @@ namespace query { #define CLONE_BINARY_EXPRESSION \ auto Clone(AstTreeStorage &storage) const->std::remove_const< \ - std::remove_pointer::type>::type * override { \ + std::remove_pointer::type>::type *override { \ return storage.Create< \ std::remove_cv::type>::type>( \ expression1_->Clone(storage), expression2_->Clone(storage)); \ } #define CLONE_UNARY_EXPRESSION \ auto Clone(AstTreeStorage &storage) const->std::remove_const< \ - std::remove_pointer::type>::type * override { \ + std::remove_pointer::type>::type *override { \ return storage.Create< \ std::remove_cv::type>::type>( \ expression_->Clone(storage)); \ @@ -901,6 +901,28 @@ class All : public Expression { } }; +class ParameterLookup : public Expression { + friend class AstTreeStorage; + + public: + DEFVISITABLE(TreeVisitor); + DEFVISITABLE(HierarchicalTreeVisitor); + + ParameterLookup *Clone(AstTreeStorage &storage) const override { + return storage.Create(token_position_); + } + + // This field contains token position of *literal* used to create + // ParameterLookup object. If ParameterLookup object is not created from + // a literal leave this value at -1. + int token_position_ = -1; + + protected: + ParameterLookup(int uid) : Expression(uid) {} + ParameterLookup(int uid, int token_position) + : Expression(uid), token_position_(token_position) {} +}; + class NamedExpression : public Tree { friend class AstTreeStorage; @@ -1652,6 +1674,7 @@ class CachedAst { } bool Visit(Identifier &) override { return true; } + bool Visit(ParameterLookup &) override { return true; } bool Visit(CreateIndex &) override { return true; } bool PreVisit(Return &) override { diff --git a/src/query/frontend/ast/ast_visitor.hpp b/src/query/frontend/ast/ast_visitor.hpp index a53559d1a..7aa7404d9 100644 --- a/src/query/frontend/ast/ast_visitor.hpp +++ b/src/query/frontend/ast/ast_visitor.hpp @@ -14,6 +14,7 @@ class EdgeTypeTest; class Aggregation; class Function; class All; +class ParameterLookup; class Create; class Match; class Return; @@ -71,8 +72,8 @@ using TreeCompositeVisitor = ::utils::CompositeVisitor< Pattern, NodeAtom, EdgeAtom, BreadthFirstAtom, Delete, Where, SetProperty, SetProperties, SetLabels, RemoveProperty, RemoveLabels, Merge, Unwind>; -using TreeLeafVisitor = - ::utils::LeafVisitor; +using TreeLeafVisitor = ::utils::LeafVisitor; class HierarchicalTreeVisitor : public TreeCompositeVisitor, public TreeLeafVisitor { @@ -92,9 +93,9 @@ using TreeVisitor = ::utils::Visitor< GreaterEqualOperator, InListOperator, ListMapIndexingOperator, ListSlicingOperator, IfOperator, UnaryPlusOperator, UnaryMinusOperator, IsNullOperator, ListLiteral, MapLiteral, PropertyLookup, LabelsTest, - EdgeTypeTest, Aggregation, Function, All, Create, Match, Return, With, - Pattern, NodeAtom, EdgeAtom, BreadthFirstAtom, Delete, Where, SetProperty, - SetProperties, SetLabels, RemoveProperty, RemoveLabels, Merge, Unwind, - Identifier, PrimitiveLiteral, CreateIndex>; + EdgeTypeTest, Aggregation, Function, All, ParameterLookup, Create, Match, + Return, With, Pattern, NodeAtom, EdgeAtom, BreadthFirstAtom, Delete, Where, + SetProperty, SetProperties, SetLabels, RemoveProperty, RemoveLabels, Merge, + Unwind, Identifier, PrimitiveLiteral, CreateIndex>; } // namespace query diff --git a/src/query/frontend/semantic/symbol_generator.hpp b/src/query/frontend/semantic/symbol_generator.hpp index d151b77d6..0b6791375 100644 --- a/src/query/frontend/semantic/symbol_generator.hpp +++ b/src/query/frontend/semantic/symbol_generator.hpp @@ -42,6 +42,7 @@ class SymbolGenerator : public HierarchicalTreeVisitor { // Expressions ReturnType Visit(Identifier &) override; ReturnType Visit(PrimitiveLiteral &) override { return true; } + ReturnType Visit(ParameterLookup &) override { return true; } bool PreVisit(Aggregation &) override; bool PostVisit(Aggregation &) override; bool PreVisit(IfOperator &) override; diff --git a/src/query/interpret/eval.hpp b/src/query/interpret/eval.hpp index 7b20af1f5..c7487b5a5 100644 --- a/src/query/interpret/eval.hpp +++ b/src/query/interpret/eval.hpp @@ -19,10 +19,12 @@ namespace query { class ExpressionEvaluator : public TreeVisitor { public: - ExpressionEvaluator(Frame &frame, const SymbolTable &symbol_table, + ExpressionEvaluator(Frame &frame, const Parameters ¶meters, + const SymbolTable &symbol_table, GraphDbAccessor &db_accessor, GraphView graph_view = GraphView::AS_IS) : frame_(frame), + parameters_(parameters), symbol_table_(symbol_table), db_accessor_(db_accessor), graph_view_(graph_view) {} @@ -388,6 +390,10 @@ class ExpressionEvaluator : public TreeVisitor { return true; } + TypedValue Visit(ParameterLookup ¶m_lookup) override { + return parameters_.AtTokenPosition(param_lookup.token_position_); + } + private: // If the given TypedValue contains accessors, switch them to New or Old, // depending on use_new_ flag. @@ -438,6 +444,7 @@ class ExpressionEvaluator : public TreeVisitor { } Frame &frame_; + const Parameters ¶meters_; const SymbolTable &symbol_table_; GraphDbAccessor &db_accessor_; // which switching approach should be used when evaluating diff --git a/src/query/interpreter.hpp b/src/query/interpreter.hpp index 0d71d8460..73a20d601 100644 --- a/src/query/interpreter.hpp +++ b/src/query/interpreter.hpp @@ -34,8 +34,7 @@ class Interpreter { Stream &stream, const std::map ¶ms) { utils::Timer frontend_timer; - Config config; - Context ctx(config, db_accessor); + Context ctx(db_accessor); std::map summary; // stripped query -> high level tree @@ -88,26 +87,25 @@ class Interpreter { .first; } - // Update literals map with provided parameters. - auto literals = stripped.literals(); + // Update context with provided parameters. + ctx.parameters_ = stripped.literals(); for (const auto ¶m_pair : stripped.parameters()) { auto param_it = params.find(param_pair.second); if (param_it == params.end()) { throw query::UnprovidedParameterError( fmt::format("Parameter$ {} not provided", param_pair.second)); } - literals.Add(param_pair.first, param_it->second); + ctx.parameters_.Add(param_pair.first, param_it->second); } // Plug literals, parameters and named expressions. - return it->second.Plug(literals, stripped.named_expressions()); + return it->second.Plug(ctx.parameters_, stripped.named_expressions()); }(); auto frontend_time = frontend_timer.Elapsed(); utils::Timer planning_timer; // symbol table fill - SymbolTable symbol_table; - SymbolGenerator symbol_generator(symbol_table); + SymbolGenerator symbol_generator(ctx.symbol_table_); ast_storage.query()->Accept(symbol_generator); // high level tree -> logical plan @@ -116,7 +114,7 @@ class Interpreter { double query_plan_cost_estimation = 0.0; if (FLAGS_query_cost_planner) { auto plans = plan::MakeLogicalPlan( - ast_storage, symbol_table, vertex_counts); + ast_storage, ctx.symbol_table_, vertex_counts); double min_cost = std::numeric_limits::max(); for (auto &plan : plans) { auto cost = EstimatePlanCost(vertex_counts, *plan); @@ -130,19 +128,19 @@ class Interpreter { query_plan_cost_estimation = min_cost; } else { logical_plan = plan::MakeLogicalPlan( - ast_storage, symbol_table, vertex_counts); + ast_storage, ctx.symbol_table_, vertex_counts); query_plan_cost_estimation = EstimatePlanCost(vertex_counts, *logical_plan); } // generate frame based on symbol table max_position - Frame frame(symbol_table.max_position()); + Frame frame(ctx.symbol_table_.max_position()); auto planning_time = planning_timer.Elapsed(); utils::Timer execution_timer; std::vector header; std::vector output_symbols( - logical_plan->OutputSymbols(symbol_table)); + logical_plan->OutputSymbols(ctx.symbol_table_)); if (!output_symbols.empty()) { // Since we have output symbols, this means that the query contains RETURN // clause, so stream out the results. @@ -153,7 +151,7 @@ class Interpreter { // stream out results auto cursor = logical_plan->MakeCursor(db_accessor); - while (cursor->Pull(frame, symbol_table)) { + while (cursor->Pull(frame, ctx)) { std::vector values; for (const auto &symbol : output_symbols) values.emplace_back(frame[symbol]); @@ -171,7 +169,7 @@ class Interpreter { dynamic_cast(logical_plan.get())) { stream.Header(header); auto cursor = logical_plan->MakeCursor(db_accessor); - while (cursor->Pull(frame, symbol_table)) continue; + while (cursor->Pull(frame, ctx)) continue; } else { throw QueryRuntimeException("Unknown top level LogicalOperator"); } diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index 062b511c0..f8c229404 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -5,6 +5,7 @@ #include "query/plan/operator.hpp" +#include "query/context.hpp" #include "query/exceptions.hpp" #include "query/frontend/ast/ast.hpp" #include "query/interpret/eval.hpp" @@ -59,7 +60,7 @@ bool EvaluateFilter(ExpressionEvaluator &evaluator, Expression *filter) { } // namespace -bool Once::OnceCursor::Pull(Frame &, const SymbolTable &) { +bool Once::OnceCursor::Pull(Frame &, Context &) { if (!did_pull_) { did_pull_ = true; return true; @@ -87,10 +88,9 @@ CreateNode::CreateNodeCursor::CreateNodeCursor(const CreateNode &self, GraphDbAccessor &db) : self_(self), db_(db), input_cursor_(self.input_->MakeCursor(db)) {} -bool CreateNode::CreateNodeCursor::Pull(Frame &frame, - const SymbolTable &symbol_table) { - if (input_cursor_->Pull(frame, symbol_table)) { - Create(frame, symbol_table); +bool CreateNode::CreateNodeCursor::Pull(Frame &frame, Context &context) { + if (input_cursor_->Pull(frame, context)) { + Create(frame, context); return true; } return false; @@ -98,17 +98,17 @@ bool CreateNode::CreateNodeCursor::Pull(Frame &frame, void CreateNode::CreateNodeCursor::Reset() { input_cursor_->Reset(); } -void CreateNode::CreateNodeCursor::Create(Frame &frame, - const SymbolTable &symbol_table) { +void CreateNode::CreateNodeCursor::Create(Frame &frame, Context &context) { auto new_node = db_.InsertVertex(); for (auto label : self_.node_atom_->labels_) new_node.add_label(label); // Evaluator should use the latest accessors, as modified in this query, when // setting properties on new nodes. - ExpressionEvaluator evaluator(frame, symbol_table, db_, GraphView::NEW); + ExpressionEvaluator evaluator(frame, context.parameters_, + context.symbol_table_, db_, GraphView::NEW); for (auto &kv : self_.node_atom_->properties_) PropsSetChecked(new_node, kv.first.second, kv.second->Accept(evaluator)); - frame[symbol_table.at(*self_.node_atom_->identifier_)] = new_node; + frame[context.symbol_table_.at(*self_.node_atom_->identifier_)] = new_node; } CreateExpand::CreateExpand(const NodeAtom *node_atom, const EdgeAtom *edge_atom, @@ -130,9 +130,8 @@ CreateExpand::CreateExpandCursor::CreateExpandCursor(const CreateExpand &self, GraphDbAccessor &db) : self_(self), db_(db), input_cursor_(self.input_->MakeCursor(db)) {} -bool CreateExpand::CreateExpandCursor::Pull(Frame &frame, - const SymbolTable &symbol_table) { - if (!input_cursor_->Pull(frame, symbol_table)) return false; +bool CreateExpand::CreateExpandCursor::Pull(Frame &frame, Context &context) { + if (!input_cursor_->Pull(frame, context)) return false; // get the origin vertex TypedValue &vertex_value = frame[self_.input_symbol_]; @@ -141,28 +140,29 @@ bool CreateExpand::CreateExpandCursor::Pull(Frame &frame, // Similarly to CreateNode, newly created edges and nodes should use the // latest accesors. - ExpressionEvaluator evaluator(frame, symbol_table, db_, GraphView::NEW); + ExpressionEvaluator evaluator(frame, context.parameters_, + context.symbol_table_, db_, GraphView::NEW); // E.g. we pickup new properties: `CREATE (n {p: 42}) -[:r {ep: n.p}]-> ()` v1.SwitchNew(); // get the destination vertex (possibly an existing node) - auto &v2 = OtherVertex(frame, symbol_table, evaluator); + auto &v2 = OtherVertex(frame, context.symbol_table_, evaluator); v2.SwitchNew(); // create an edge between the two nodes switch (self_.edge_atom_->direction_) { case EdgeAtom::Direction::IN: - CreateEdge(v2, v1, frame, symbol_table, evaluator); + CreateEdge(v2, v1, frame, context.symbol_table_, evaluator); break; case EdgeAtom::Direction::OUT: - CreateEdge(v1, v2, frame, symbol_table, evaluator); + CreateEdge(v1, v2, frame, context.symbol_table_, evaluator); break; case EdgeAtom::Direction::BOTH: // in the case of an undirected CreateExpand we choose an arbitrary // direction. this is used in the MERGE clause // it is not allowed in the CREATE clause, and the semantic // checker needs to ensure it doesn't reach this point - CreateEdge(v1, v2, frame, symbol_table, evaluator); + CreateEdge(v1, v2, frame, context.symbol_table_, evaluator); } return true; @@ -211,13 +211,13 @@ class ScanAllCursor : public Cursor { get_vertices_(std::move(get_vertices)), db_(db) {} - bool Pull(Frame &frame, const SymbolTable &symbol_table) override { + bool Pull(Frame &frame, Context &context) override { if (db_.should_abort()) throw HintedAbortError(); if (!vertices_ || vertices_it_.value() == vertices_.value().end()) { - if (!input_cursor_->Pull(frame, symbol_table)) return false; + if (!input_cursor_->Pull(frame, context)) return false; // We need a getter function, because in case of exhausting a lazy // iterable, we cannot simply reset it by calling begin(). - vertices_.emplace(get_vertices_(frame, symbol_table)); + vertices_.emplace(get_vertices_(frame, context)); vertices_it_.emplace(vertices_.value().begin()); } @@ -240,7 +240,7 @@ class ScanAllCursor : public Cursor { const std::unique_ptr input_cursor_; TVerticesFun get_vertices_; std::experimental::optional< - typename std::result_of::type> + typename std::result_of::type> vertices_; std::experimental::optional vertices_it_; GraphDbAccessor &db_; @@ -258,7 +258,7 @@ ScanAll::ScanAll(const std::shared_ptr &input, ACCEPT_WITH_INPUT(ScanAll) std::unique_ptr ScanAll::MakeCursor(GraphDbAccessor &db) { - auto vertices = [this, &db](Frame &, const SymbolTable &) { + auto vertices = [this, &db](Frame &, Context &) { return db.Vertices(graph_view_ == GraphView::NEW); }; return std::make_unique>( @@ -273,7 +273,7 @@ ScanAllByLabel::ScanAllByLabel(const std::shared_ptr &input, ACCEPT_WITH_INPUT(ScanAllByLabel) std::unique_ptr ScanAllByLabel::MakeCursor(GraphDbAccessor &db) { - auto vertices = [this, &db](Frame &, const SymbolTable &) { + auto vertices = [this, &db](Frame &, Context &) { return db.Vertices(label_, graph_view_ == GraphView::NEW); }; return std::make_unique>( @@ -308,15 +308,15 @@ std::unique_ptr ScanAllByLabelPropertyRange::MakeCursor( } return false; }; - auto vertices = [this, &db, is_less](Frame &frame, - const SymbolTable &symbol_table) { - ExpressionEvaluator evaluator(frame, symbol_table, db, graph_view_); + auto vertices = [this, &db, is_less](Frame &frame, Context &context) { + ExpressionEvaluator evaluator(frame, context.parameters_, + context.symbol_table_, db, graph_view_); auto convert = [&evaluator](const auto &bound) -> std::experimental::optional> { - if (!bound) return std::experimental::nullopt; - return std::experimental::make_optional(utils::Bound( - bound.value().value()->Accept(evaluator), bound.value().type())); - }; + if (!bound) return std::experimental::nullopt; + return std::experimental::make_optional(utils::Bound( + bound.value().value()->Accept(evaluator), bound.value().type())); + }; return db.Vertices(label_, property_, convert(lower_bound()), convert(upper_bound()), graph_view_ == GraphView::NEW); }; @@ -343,14 +343,15 @@ class ScanAllByLabelPropertyValueCursor : public Cursor { GraphDbAccessor &db) : self_(self), db_(db), input_cursor_(self_.input()->MakeCursor(db_)) {} - bool Pull(Frame &frame, const SymbolTable &symbol_table) override { + bool Pull(Frame &frame, Context &context) override { if (db_.should_abort()) throw HintedAbortError(); if (!vertices_ || vertices_it_.value() == vertices_.value().end()) { - if (!input_cursor_->Pull(frame, symbol_table)) return false; - ExpressionEvaluator evaluator(frame, symbol_table, db_, + if (!input_cursor_->Pull(frame, context)) return false; + ExpressionEvaluator evaluator(frame, context.parameters_, + context.symbol_table_, db_, self_.graph_view()); TypedValue value = self_.expression()->Accept(evaluator); - if (value.IsNull()) return Pull(frame, symbol_table); + if (value.IsNull()) return Pull(frame, context); try { vertices_.emplace(db_.Vertices(self_.label(), self_.property(), value, self_.graph_view() == GraphView::NEW)); @@ -429,7 +430,7 @@ std::unique_ptr Expand::MakeCursor(GraphDbAccessor &db) { Expand::ExpandCursor::ExpandCursor(const Expand &self, GraphDbAccessor &db) : self_(self), input_cursor_(self.input_->MakeCursor(db)), db_(db) {} -bool Expand::ExpandCursor::Pull(Frame &frame, const SymbolTable &symbol_table) { +bool Expand::ExpandCursor::Pull(Frame &frame, Context &context) { // Helper function for handling existing-edge checking. Returns false only if // existing_edge is true and the given new_edge is not equal to the existing // one. @@ -492,7 +493,7 @@ bool Expand::ExpandCursor::Pull(Frame &frame, const SymbolTable &symbol_table) { // if we are here, either the edges have not been initialized, // or they have been exhausted. attempt to initialize the edges, // if the input is exhausted - if (!InitEdges(frame, symbol_table)) return false; + if (!InitEdges(frame, context)) return false; // we have re-initialized the edges, continue with the loop } @@ -523,12 +524,11 @@ void SwitchAccessor(TAccessor &accessor, GraphView graph_view) { } } -bool Expand::ExpandCursor::InitEdges(Frame &frame, - const SymbolTable &symbol_table) { +bool Expand::ExpandCursor::InitEdges(Frame &frame, Context &context) { // Input Vertex could be null if it is created by a failed optional match. In // those cases we skip that input pull and continue with the next. while (true) { - if (!input_cursor_->Pull(frame, symbol_table)) return false; + if (!input_cursor_->Pull(frame, context)) return false; TypedValue &vertex_value = frame[self_.input_symbol_]; // Null check due to possible failed optional match. @@ -681,12 +681,14 @@ class ExpandVariableCursor : public Cursor { ExpandVariableCursor(const ExpandVariable &self, GraphDbAccessor &db) : self_(self), db_(db), input_cursor_(self.input_->MakeCursor(db)) {} - bool Pull(Frame &frame, const SymbolTable &symbol_table) override { - ExpressionEvaluator evaluator(frame, symbol_table, db_, self_.graph_view_); + bool Pull(Frame &frame, Context &context) override { + ExpressionEvaluator evaluator(frame, context.parameters_, + context.symbol_table_, db_, + self_.graph_view_); while (true) { - if (Expand(frame, symbol_table)) return true; + if (Expand(frame, context)) return true; - if (PullInput(frame, symbol_table)) { + if (PullInput(frame, context)) { // if lower bound is zero we also yield empty paths if (lower_bound_ == 0) { auto &edges_on_frame = @@ -745,11 +747,11 @@ class ExpandVariableCursor : public Cursor { * @return If the Pull succeeded. If not, this VariableExpandCursor * is exhausted. */ - bool PullInput(Frame &frame, const SymbolTable &symbol_table) { + bool PullInput(Frame &frame, Context &context) { // Input Vertex could be null if it is created by a failed optional match. // In those cases we skip that input pull and continue with the next. while (true) { - if (!input_cursor_->Pull(frame, symbol_table)) return false; + if (!input_cursor_->Pull(frame, context)) return false; TypedValue &vertex_value = frame[self_.input_symbol_]; // Null check due to possible failed optional match. @@ -760,7 +762,8 @@ class ExpandVariableCursor : public Cursor { SwitchAccessor(vertex, self_.graph_view_); // Evaluate the upper and lower bounds. - ExpressionEvaluator evaluator(frame, symbol_table, db_); + ExpressionEvaluator evaluator(frame, context.parameters_, + context.symbol_table_, db_); auto calc_bound = [this, &evaluator](auto &bound) { auto value = EvaluateInt(evaluator, bound, "Variable expansion bound"); if (value < 0) @@ -869,8 +872,10 @@ class ExpandVariableCursor : public Cursor { * case no more expansions are available from the current input * vertex and another Pull from the input cursor should be performed. */ - bool Expand(Frame &frame, const SymbolTable &symbol_table) { - ExpressionEvaluator evaluator(frame, symbol_table, db_, self_.graph_view_); + bool Expand(Frame &frame, Context &context) { + ExpressionEvaluator evaluator(frame, context.parameters_, + context.symbol_table_, db_, + self_.graph_view_); // some expansions might not be valid due to // edge uniqueness, existing_edge, existing_node criterions, // so expand in a loop until either the input vertex is @@ -994,10 +999,10 @@ ExpandBreadthFirst::Cursor::Cursor(const ExpandBreadthFirst &self, GraphDbAccessor &db) : self_(self), db_(db), input_cursor_(self_.input_->MakeCursor(db)) {} -bool ExpandBreadthFirst::Cursor::Pull(Frame &frame, - const SymbolTable &symbol_table) { +bool ExpandBreadthFirst::Cursor::Pull(Frame &frame, Context &context) { // evaulator for the filtering condition - ExpressionEvaluator evaluator(frame, symbol_table, db_, self_.graph_view_); + ExpressionEvaluator evaluator(frame, context.parameters_, + context.symbol_table_, db_, self_.graph_view_); // for the given (edge, vertex) pair checks if they satisfy the // "where" condition. if so, places them in the to_visit_ structure. @@ -1061,7 +1066,7 @@ bool ExpandBreadthFirst::Cursor::Pull(Frame &frame, // if current is still empty, it means both are empty, so pull from input if (to_visit_current_.empty()) { - if (!input_cursor_->Pull(frame, symbol_table)) return false; + if (!input_cursor_->Pull(frame, context)) return false; processed_.clear(); auto vertex_value = frame[self_.input_symbol_]; @@ -1141,11 +1146,12 @@ std::unique_ptr Filter::MakeCursor(GraphDbAccessor &db) { Filter::FilterCursor::FilterCursor(const Filter &self, GraphDbAccessor &db) : self_(self), db_(db), input_cursor_(self_.input_->MakeCursor(db)) {} -bool Filter::FilterCursor::Pull(Frame &frame, const SymbolTable &symbol_table) { +bool Filter::FilterCursor::Pull(Frame &frame, Context &context) { // Like all filters, newly set values should not affect filtering of old nodes // and edges. - ExpressionEvaluator evaluator(frame, symbol_table, db_, GraphView::OLD); - while (input_cursor_->Pull(frame, symbol_table)) { + ExpressionEvaluator evaluator(frame, context.parameters_, + context.symbol_table_, db_, GraphView::OLD); + while (input_cursor_->Pull(frame, context)) { if (EvaluateFilter(evaluator, self_.expression_)) return true; } return false; @@ -1179,11 +1185,11 @@ const std::vector &Produce::named_expressions() { Produce::ProduceCursor::ProduceCursor(const Produce &self, GraphDbAccessor &db) : self_(self), db_(db), input_cursor_(self_.input_->MakeCursor(db)) {} -bool Produce::ProduceCursor::Pull(Frame &frame, - const SymbolTable &symbol_table) { - if (input_cursor_->Pull(frame, symbol_table)) { +bool Produce::ProduceCursor::Pull(Frame &frame, Context &context) { + if (input_cursor_->Pull(frame, context)) { // Produce should always yield the latest results. - ExpressionEvaluator evaluator(frame, symbol_table, db_, GraphView::NEW); + ExpressionEvaluator evaluator(frame, context.parameters_, + context.symbol_table_, db_, GraphView::NEW); for (auto named_expr : self_.named_expressions_) named_expr->Accept(evaluator); return true; @@ -1206,12 +1212,13 @@ std::unique_ptr Delete::MakeCursor(GraphDbAccessor &db) { Delete::DeleteCursor::DeleteCursor(const Delete &self, GraphDbAccessor &db) : self_(self), db_(db), input_cursor_(self_.input_->MakeCursor(db)) {} -bool Delete::DeleteCursor::Pull(Frame &frame, const SymbolTable &symbol_table) { - if (!input_cursor_->Pull(frame, symbol_table)) return false; +bool Delete::DeleteCursor::Pull(Frame &frame, Context &context) { + if (!input_cursor_->Pull(frame, context)) return false; // Delete should get the latest information, this way it is also possible to // delete newly added nodes and edges. - ExpressionEvaluator evaluator(frame, symbol_table, db_, GraphView::NEW); + ExpressionEvaluator evaluator(frame, context.parameters_, + context.symbol_table_, db_, GraphView::NEW); // collect expressions results so edges can get deleted before vertices // this is necessary because an edge that gets deleted could block vertex // deletion @@ -1270,12 +1277,12 @@ SetProperty::SetPropertyCursor::SetPropertyCursor(const SetProperty &self, GraphDbAccessor &db) : self_(self), db_(db), input_cursor_(self.input_->MakeCursor(db)) {} -bool SetProperty::SetPropertyCursor::Pull(Frame &frame, - const SymbolTable &symbol_table) { - if (!input_cursor_->Pull(frame, symbol_table)) return false; +bool SetProperty::SetPropertyCursor::Pull(Frame &frame, Context &context) { + if (!input_cursor_->Pull(frame, context)) return false; // Set, just like Create needs to see the latest changes. - ExpressionEvaluator evaluator(frame, symbol_table, db_, GraphView::NEW); + ExpressionEvaluator evaluator(frame, context.parameters_, + context.symbol_table_, db_, GraphView::NEW); TypedValue lhs = self_.lhs_->expression_->Accept(evaluator); TypedValue rhs = self_.rhs_->Accept(evaluator); @@ -1318,14 +1325,14 @@ SetProperties::SetPropertiesCursor::SetPropertiesCursor( const SetProperties &self, GraphDbAccessor &db) : self_(self), db_(db), input_cursor_(self.input_->MakeCursor(db)) {} -bool SetProperties::SetPropertiesCursor::Pull(Frame &frame, - const SymbolTable &symbol_table) { - if (!input_cursor_->Pull(frame, symbol_table)) return false; +bool SetProperties::SetPropertiesCursor::Pull(Frame &frame, Context &context) { + if (!input_cursor_->Pull(frame, context)) return false; TypedValue &lhs = frame[self_.input_symbol_]; // Set, just like Create needs to see the latest changes. - ExpressionEvaluator evaluator(frame, symbol_table, db_, GraphView::NEW); + ExpressionEvaluator evaluator(frame, context.parameters_, + context.symbol_table_, db_, GraphView::NEW); TypedValue rhs = self_.rhs_->Accept(evaluator); switch (lhs.type()) { @@ -1396,9 +1403,8 @@ SetLabels::SetLabelsCursor::SetLabelsCursor(const SetLabels &self, GraphDbAccessor &db) : self_(self), input_cursor_(self.input_->MakeCursor(db)) {} -bool SetLabels::SetLabelsCursor::Pull(Frame &frame, - const SymbolTable &symbol_table) { - if (!input_cursor_->Pull(frame, symbol_table)) return false; +bool SetLabels::SetLabelsCursor::Pull(Frame &frame, Context &context) { + if (!input_cursor_->Pull(frame, context)) return false; TypedValue &vertex_value = frame[self_.input_symbol_]; // Skip setting labels on Null (can occur in optional match). @@ -1427,12 +1433,13 @@ RemoveProperty::RemovePropertyCursor::RemovePropertyCursor( const RemoveProperty &self, GraphDbAccessor &db) : self_(self), db_(db), input_cursor_(self.input_->MakeCursor(db)) {} -bool RemoveProperty::RemovePropertyCursor::Pull( - Frame &frame, const SymbolTable &symbol_table) { - if (!input_cursor_->Pull(frame, symbol_table)) return false; +bool RemoveProperty::RemovePropertyCursor::Pull(Frame &frame, + Context &context) { + if (!input_cursor_->Pull(frame, context)) return false; // Remove, just like Delete needs to see the latest changes. - ExpressionEvaluator evaluator(frame, symbol_table, db_, GraphView::NEW); + ExpressionEvaluator evaluator(frame, context.parameters_, + context.symbol_table_, db_, GraphView::NEW); TypedValue lhs = self_.lhs_->expression_->Accept(evaluator); switch (lhs.type()) { @@ -1469,9 +1476,8 @@ RemoveLabels::RemoveLabelsCursor::RemoveLabelsCursor(const RemoveLabels &self, GraphDbAccessor &db) : self_(self), input_cursor_(self.input_->MakeCursor(db)) {} -bool RemoveLabels::RemoveLabelsCursor::Pull(Frame &frame, - const SymbolTable &symbol_table) { - if (!input_cursor_->Pull(frame, symbol_table)) return false; +bool RemoveLabels::RemoveLabelsCursor::Pull(Frame &frame, Context &context) { + if (!input_cursor_->Pull(frame, context)) return false; TypedValue &vertex_value = frame[self_.input_symbol_]; // Skip removing labels on Null (can occur in optional match). @@ -1541,7 +1547,7 @@ bool ContainsSame(const TypedValue &a, const TypedValue &b) { template bool ExpandUniquenessFilter::ExpandUniquenessFilterCursor::Pull( - Frame &frame, const SymbolTable &symbol_table) { + Frame &frame, Context &context) { auto expansion_ok = [&]() { TypedValue &expand_value = frame[self_.expand_symbol_]; for (const auto &previous_symbol : self_.previous_symbols_) { @@ -1555,7 +1561,7 @@ bool ExpandUniquenessFilter::ExpandUniquenessFilterCursor::Pull( return true; }; - while (input_cursor_->Pull(frame, symbol_table)) + while (input_cursor_->Pull(frame, context)) if (expansion_ok()) return true; return false; } @@ -1621,11 +1627,10 @@ Accumulate::AccumulateCursor::AccumulateCursor(const Accumulate &self, GraphDbAccessor &db) : self_(self), db_(db), input_cursor_(self.input_->MakeCursor(db)) {} -bool Accumulate::AccumulateCursor::Pull(Frame &frame, - const SymbolTable &symbol_table) { +bool Accumulate::AccumulateCursor::Pull(Frame &frame, Context &context) { // cache all the input if (!pulled_all_input_) { - while (input_cursor_->Pull(frame, symbol_table)) { + while (input_cursor_->Pull(frame, context)) { std::vector row; row.reserve(self_.symbols_.size()); for (const Symbol &symbol : self_.symbols_) @@ -1696,10 +1701,9 @@ TypedValue DefaultAggregationOpValue(const Aggregate::Element &element) { } } -bool Aggregate::AggregateCursor::Pull(Frame &frame, - const SymbolTable &symbol_table) { +bool Aggregate::AggregateCursor::Pull(Frame &frame, Context &context) { if (!pulled_all_input_) { - ProcessAll(frame, symbol_table); + ProcessAll(frame, context); pulled_all_input_ = true; aggregation_it_ = aggregation_.begin(); @@ -1732,11 +1736,11 @@ bool Aggregate::AggregateCursor::Pull(Frame &frame, return true; } -void Aggregate::AggregateCursor::ProcessAll(Frame &frame, - const SymbolTable &symbol_table) { - ExpressionEvaluator evaluator(frame, symbol_table, db_, GraphView::NEW); - while (input_cursor_->Pull(frame, symbol_table)) - ProcessOne(frame, symbol_table, evaluator); +void Aggregate::AggregateCursor::ProcessAll(Frame &frame, Context &context) { + ExpressionEvaluator evaluator(frame, context.parameters_, + context.symbol_table_, db_, GraphView::NEW); + while (input_cursor_->Pull(frame, context)) + ProcessOne(frame, context.symbol_table_, evaluator); // calculate AVG aggregations (so far they have only been summed) for (int pos = 0; pos < static_cast(self_.aggregations_.size()); ++pos) { @@ -1951,12 +1955,13 @@ std::vector Skip::OutputSymbols(const SymbolTable &symbol_table) { Skip::SkipCursor::SkipCursor(Skip &self, GraphDbAccessor &db) : self_(self), db_(db), input_cursor_(self_.input_->MakeCursor(db)) {} -bool Skip::SkipCursor::Pull(Frame &frame, const SymbolTable &symbol_table) { - while (input_cursor_->Pull(frame, symbol_table)) { +bool Skip::SkipCursor::Pull(Frame &frame, Context &context) { + while (input_cursor_->Pull(frame, context)) { if (to_skip_ == -1) { // first successful pull from the input // evaluate the skip expression - ExpressionEvaluator evaluator(frame, symbol_table, db_); + ExpressionEvaluator evaluator(frame, context.parameters_, + context.symbol_table_, db_); TypedValue to_skip = self_.expression_->Accept(evaluator); if (to_skip.type() != TypedValue::Type::Int) throw QueryRuntimeException("Result of SKIP expression must be an int"); @@ -1997,13 +2002,14 @@ std::vector Limit::OutputSymbols(const SymbolTable &symbol_table) { Limit::LimitCursor::LimitCursor(Limit &self, GraphDbAccessor &db) : self_(self), db_(db), input_cursor_(self_.input_->MakeCursor(db)) {} -bool Limit::LimitCursor::Pull(Frame &frame, const SymbolTable &symbol_table) { +bool Limit::LimitCursor::Pull(Frame &frame, Context &context) { // we need to evaluate the limit expression before the first input Pull // because it might be 0 and thereby we shouldn't Pull from input at all // we can do this before Pulling from the input because the limit expression // is not allowed to contain any identifiers if (limit_ == -1) { - ExpressionEvaluator evaluator(frame, symbol_table, db_); + ExpressionEvaluator evaluator(frame, context.parameters_, + context.symbol_table_, db_); TypedValue limit = self_.expression_->Accept(evaluator); if (limit.type() != TypedValue::Type::Int) throw QueryRuntimeException("Result of LIMIT expression must be an int"); @@ -2017,7 +2023,7 @@ bool Limit::LimitCursor::Pull(Frame &frame, const SymbolTable &symbol_table) { // check we have not exceeded the limit before pulling if (pulled_++ >= limit_) return false; - return input_cursor_->Pull(frame, symbol_table); + return input_cursor_->Pull(frame, context); } void Limit::LimitCursor::Reset() { @@ -2055,11 +2061,11 @@ std::vector OrderBy::OutputSymbols(const SymbolTable &symbol_table) { OrderBy::OrderByCursor::OrderByCursor(OrderBy &self, GraphDbAccessor &db) : self_(self), db_(db), input_cursor_(self_.input_->MakeCursor(db)) {} -bool OrderBy::OrderByCursor::Pull(Frame &frame, - const SymbolTable &symbol_table) { +bool OrderBy::OrderByCursor::Pull(Frame &frame, Context &context) { if (!did_pull_all_) { - ExpressionEvaluator evaluator(frame, symbol_table, db_); - while (input_cursor_->Pull(frame, symbol_table)) { + ExpressionEvaluator evaluator(frame, context.parameters_, + context.symbol_table_, db_); + while (input_cursor_->Pull(frame, context)) { // collect the order_by elements std::vector order_by; order_by.reserve(self_.order_by_.size()); @@ -2196,9 +2202,9 @@ Merge::MergeCursor::MergeCursor(Merge &self, GraphDbAccessor &db) merge_match_cursor_(self.merge_match_->MakeCursor(db)), merge_create_cursor_(self.merge_create_->MakeCursor(db)) {} -bool Merge::MergeCursor::Pull(Frame &frame, const SymbolTable &symbol_table) { +bool Merge::MergeCursor::Pull(Frame &frame, Context &context) { if (pull_input_) { - if (input_cursor_->Pull(frame, symbol_table)) { + if (input_cursor_->Pull(frame, context)) { // after a successful input from the input // reset merge_match (it's expand iterators maintain state) // and merge_create (could have a Once at the beginning) @@ -2210,7 +2216,7 @@ bool Merge::MergeCursor::Pull(Frame &frame, const SymbolTable &symbol_table) { } // pull from the merge_match cursor - if (merge_match_cursor_->Pull(frame, symbol_table)) { + if (merge_match_cursor_->Pull(frame, context)) { // if successful, next Pull from this should not pull_input_ pull_input_ = false; return true; @@ -2220,14 +2226,14 @@ bool Merge::MergeCursor::Pull(Frame &frame, const SymbolTable &symbol_table) { // if we have just now pulled from the input // and failed to pull from merge_match, we should create __attribute__((unused)) bool merge_create_pull_result = - merge_create_cursor_->Pull(frame, symbol_table); + merge_create_cursor_->Pull(frame, context); debug_assert(merge_create_pull_result, "MergeCreate must never fail"); return true; } // we have exhausted merge_match_cursor_ after 1 or more successful Pulls // attempt next input_cursor_ pull pull_input_ = true; - return Pull(frame, symbol_table); + return Pull(frame, context); } } @@ -2261,10 +2267,9 @@ Optional::OptionalCursor::OptionalCursor(Optional &self, GraphDbAccessor &db) input_cursor_(self.input_->MakeCursor(db)), optional_cursor_(self.optional_->MakeCursor(db)) {} -bool Optional::OptionalCursor::Pull(Frame &frame, - const SymbolTable &symbol_table) { +bool Optional::OptionalCursor::Pull(Frame &frame, Context &context) { if (pull_input_) { - if (input_cursor_->Pull(frame, symbol_table)) { + if (input_cursor_->Pull(frame, context)) { // after a successful input from the input // reset optional_ (it's expand iterators maintain state) optional_cursor_->Reset(); @@ -2274,7 +2279,7 @@ bool Optional::OptionalCursor::Pull(Frame &frame, } // pull from the optional_ cursor - if (optional_cursor_->Pull(frame, symbol_table)) { + if (optional_cursor_->Pull(frame, context)) { // if successful, next Pull from this should not pull_input_ pull_input_ = false; return true; @@ -2293,7 +2298,7 @@ bool Optional::OptionalCursor::Pull(Frame &frame, // we have exhausted optional_cursor_ after 1 or more successful Pulls // attempt next input_cursor_ pull pull_input_ = true; - return Pull(frame, symbol_table); + return Pull(frame, context); } } @@ -2318,15 +2323,16 @@ std::unique_ptr Unwind::MakeCursor(GraphDbAccessor &db) { Unwind::UnwindCursor::UnwindCursor(Unwind &self, GraphDbAccessor &db) : self_(self), db_(db), input_cursor_(self.input_->MakeCursor(db)) {} -bool Unwind::UnwindCursor::Pull(Frame &frame, const SymbolTable &symbol_table) { +bool Unwind::UnwindCursor::Pull(Frame &frame, Context &context) { if (db_.should_abort()) throw HintedAbortError(); // if we reached the end of our list of values // pull from the input if (input_value_it_ == input_value_.end()) { - if (!input_cursor_->Pull(frame, symbol_table)) return false; + if (!input_cursor_->Pull(frame, context)) return false; // successful pull from input, initialize value and iterator - ExpressionEvaluator evaluator(frame, symbol_table, db_); + ExpressionEvaluator evaluator(frame, context.parameters_, + context.symbol_table_, db_); TypedValue input_value = self_.input_expression_->Accept(evaluator); if (input_value.type() != TypedValue::Type::List) throw QueryRuntimeException("UNWIND only accepts list values, got '{}'", @@ -2336,7 +2342,7 @@ bool Unwind::UnwindCursor::Pull(Frame &frame, const SymbolTable &symbol_table) { } // if we reached the end of our list of values goto back to top - if (input_value_it_ == input_value_.end()) return Pull(frame, symbol_table); + if (input_value_it_ == input_value_.end()) return Pull(frame, context); frame[self_.output_symbol_] = *input_value_it_++; return true; @@ -2367,10 +2373,9 @@ std::vector Distinct::OutputSymbols(const SymbolTable &symbol_table) { Distinct::DistinctCursor::DistinctCursor(Distinct &self, GraphDbAccessor &db) : self_(self), input_cursor_(self.input_->MakeCursor(db)) {} -bool Distinct::DistinctCursor::Pull(Frame &frame, - const SymbolTable &symbol_table) { +bool Distinct::DistinctCursor::Pull(Frame &frame, Context &context) { while (true) { - if (!input_cursor_->Pull(frame, symbol_table)) return false; + if (!input_cursor_->Pull(frame, context)) return false; std::vector row; row.reserve(self_.value_symbols_.size()); @@ -2398,7 +2403,7 @@ class CreateIndexCursor : public Cursor { CreateIndexCursor(CreateIndex &self, GraphDbAccessor &db) : self_(self), db_(db) {} - bool Pull(Frame &, const SymbolTable &) override { + bool Pull(Frame &, Context &) override { if (did_create_) return false; try { db_.BuildIndex(self_.label(), self_.property()); diff --git a/src/query/plan/operator.hpp b/src/query/plan/operator.hpp index fadee48a0..9ae4ea10b 100644 --- a/src/query/plan/operator.hpp +++ b/src/query/plan/operator.hpp @@ -24,8 +24,9 @@ namespace query { -class Frame; +class Context; class ExpressionEvaluator; +class Frame; namespace plan { @@ -43,9 +44,10 @@ class Cursor { * * @param Frame May be read from or written to while performing the * iteration. - * @param SymbolTable Used to get the position of symbols in frame. + * @param Context Used to get the position of symbols in frame and other + * information. */ - virtual bool Pull(Frame &, const SymbolTable &) = 0; + virtual bool Pull(Frame &, Context &) = 0; /** * Resets the Cursor to it's initial state. @@ -156,7 +158,7 @@ class Once : public LogicalOperator { private: class OnceCursor : public Cursor { public: - bool Pull(Frame &frame, const SymbolTable &symbol_table) override; + bool Pull(Frame &, Context &) override; void Reset() override; private: @@ -194,7 +196,7 @@ class CreateNode : public LogicalOperator { class CreateNodeCursor : public Cursor { public: CreateNodeCursor(const CreateNode &self, GraphDbAccessor &db); - bool Pull(Frame &frame, const SymbolTable &symbol_table) override; + bool Pull(Frame &, Context &) override; void Reset() override; private: @@ -205,7 +207,7 @@ class CreateNode : public LogicalOperator { /** * Creates a single node and places it in the frame. */ - void Create(Frame &frame, const SymbolTable &symbol_table); + void Create(Frame &, Context &); }; }; @@ -259,7 +261,7 @@ class CreateExpand : public LogicalOperator { class CreateExpandCursor : public Cursor { public: CreateExpandCursor(const CreateExpand &self, GraphDbAccessor &db); - bool Pull(Frame &frame, const SymbolTable &symbol_table) override; + bool Pull(Frame &, Context &) override; void Reset() override; private: @@ -553,7 +555,7 @@ class Expand : public LogicalOperator, public ExpandCommon { class ExpandCursor : public Cursor { public: ExpandCursor(const Expand &self, GraphDbAccessor &db); - bool Pull(Frame &frame, const SymbolTable &symbol_table) override; + bool Pull(Frame &, Context &) override; void Reset() override; private: @@ -569,7 +571,7 @@ class Expand : public LogicalOperator, public ExpandCommon { std::experimental::optional out_edges_; std::experimental::optional out_edges_it_; - bool InitEdges(Frame &frame, const SymbolTable &symbol_table); + bool InitEdges(Frame &, Context &); }; }; @@ -678,7 +680,7 @@ class ExpandBreadthFirst : public LogicalOperator { class Cursor : public query::plan::Cursor { public: Cursor(const ExpandBreadthFirst &self, GraphDbAccessor &db); - bool Pull(Frame &frame, const SymbolTable &symbol_table) override; + bool Pull(Frame &, Context &) override; void Reset() override; private: @@ -751,7 +753,7 @@ class Filter : public LogicalOperator { class FilterCursor : public Cursor { public: FilterCursor(const Filter &self, GraphDbAccessor &db); - bool Pull(Frame &frame, const SymbolTable &symbol_table) override; + bool Pull(Frame &, Context &) override; void Reset() override; private: @@ -788,7 +790,7 @@ class Produce : public LogicalOperator { class ProduceCursor : public Cursor { public: ProduceCursor(const Produce &self, GraphDbAccessor &db); - bool Pull(Frame &frame, const SymbolTable &symbol_table) override; + bool Pull(Frame &, Context &) override; void Reset() override; private: @@ -822,7 +824,7 @@ class Delete : public LogicalOperator { class DeleteCursor : public Cursor { public: DeleteCursor(const Delete &self, GraphDbAccessor &db); - bool Pull(Frame &frame, const SymbolTable &symbol_table) override; + bool Pull(Frame &, Context &) override; void Reset() override; private: @@ -853,7 +855,7 @@ class SetProperty : public LogicalOperator { class SetPropertyCursor : public Cursor { public: SetPropertyCursor(const SetProperty &self, GraphDbAccessor &db); - bool Pull(Frame &frame, const SymbolTable &symbol_table) override; + bool Pull(Frame &, Context &) override; void Reset() override; private: @@ -901,7 +903,7 @@ class SetProperties : public LogicalOperator { class SetPropertiesCursor : public Cursor { public: SetPropertiesCursor(const SetProperties &self, GraphDbAccessor &db); - bool Pull(Frame &frame, const SymbolTable &symbol_table) override; + bool Pull(Frame &, Context &) override; void Reset() override; private: @@ -940,7 +942,7 @@ class SetLabels : public LogicalOperator { class SetLabelsCursor : public Cursor { public: SetLabelsCursor(const SetLabels &self, GraphDbAccessor &db); - bool Pull(Frame &frame, const SymbolTable &symbol_table) override; + bool Pull(Frame &, Context &) override; void Reset() override; private: @@ -967,7 +969,7 @@ class RemoveProperty : public LogicalOperator { class RemovePropertyCursor : public Cursor { public: RemovePropertyCursor(const RemoveProperty &self, GraphDbAccessor &db); - bool Pull(Frame &frame, const SymbolTable &symbol_table) override; + bool Pull(Frame &, Context &) override; void Reset() override; private: @@ -999,7 +1001,7 @@ class RemoveLabels : public LogicalOperator { class RemoveLabelsCursor : public Cursor { public: RemoveLabelsCursor(const RemoveLabels &self, GraphDbAccessor &db); - bool Pull(Frame &frame, const SymbolTable &symbol_table) override; + bool Pull(Frame &, Context &) override; void Reset() override; private: @@ -1049,7 +1051,7 @@ class ExpandUniquenessFilter : public LogicalOperator { public: ExpandUniquenessFilterCursor(const ExpandUniquenessFilter &self, GraphDbAccessor &db); - bool Pull(Frame &frame, const SymbolTable &symbol_table) override; + bool Pull(Frame &, Context &) override; void Reset() override; private: @@ -1102,7 +1104,7 @@ class Accumulate : public LogicalOperator { class AccumulateCursor : public Cursor { public: AccumulateCursor(const Accumulate &self, GraphDbAccessor &db); - bool Pull(Frame &frame, const SymbolTable &symbol_table) override; + bool Pull(Frame &, Context &) override; void Reset() override; private: @@ -1170,7 +1172,7 @@ class Aggregate : public LogicalOperator { class AggregateCursor : public Cursor { public: AggregateCursor(Aggregate &self, GraphDbAccessor &db); - bool Pull(Frame &frame, const SymbolTable &symbol_table) override; + bool Pull(Frame &, Context &) override; void Reset() override; private: @@ -1218,7 +1220,7 @@ class Aggregate : public LogicalOperator { * cache cardinality depends on number of * aggregation results, and not on the number of inputs. */ - void ProcessAll(Frame &frame, const SymbolTable &symbol_table); + void ProcessAll(Frame &, Context &); /** * Performs a single accumulation. @@ -1273,7 +1275,7 @@ class Skip : public LogicalOperator { class SkipCursor : public Cursor { public: SkipCursor(Skip &self, GraphDbAccessor &db); - bool Pull(Frame &frame, const SymbolTable &symbol_table) override; + bool Pull(Frame &, Context &) override; void Reset() override; private: @@ -1316,7 +1318,7 @@ class Limit : public LogicalOperator { class LimitCursor : public Cursor { public: LimitCursor(Limit &self, GraphDbAccessor &db); - bool Pull(Frame &frame, const SymbolTable &symbol_table) override; + bool Pull(Frame &, Context &) override; void Reset() override; private: @@ -1384,7 +1386,7 @@ class OrderBy : public LogicalOperator { class OrderByCursor : public Cursor { public: OrderByCursor(OrderBy &self, GraphDbAccessor &db); - bool Pull(Frame &frame, const SymbolTable &symbol_table) override; + bool Pull(Frame &, Context &) override; void Reset() override; private: @@ -1436,7 +1438,7 @@ class Merge : public LogicalOperator { class MergeCursor : public Cursor { public: MergeCursor(Merge &self, GraphDbAccessor &db); - bool Pull(Frame &frame, const SymbolTable &symbol_table) override; + bool Pull(Frame &, Context &) override; void Reset() override; private: @@ -1481,7 +1483,7 @@ class Optional : public LogicalOperator { class OptionalCursor : public Cursor { public: OptionalCursor(Optional &self, GraphDbAccessor &db); - bool Pull(Frame &frame, const SymbolTable &symbol_table) override; + bool Pull(Frame &, Context &) override; void Reset() override; private: @@ -1520,7 +1522,7 @@ class Unwind : public LogicalOperator { class UnwindCursor : public Cursor { public: UnwindCursor(Unwind &self, GraphDbAccessor &db); - bool Pull(Frame &frame, const SymbolTable &symbol_table) override; + bool Pull(Frame &, Context &) override; void Reset() override; private: @@ -1558,7 +1560,7 @@ class Distinct : public LogicalOperator { public: DistinctCursor(Distinct &self, GraphDbAccessor &db); - bool Pull(Frame &frame, const SymbolTable &symbol_table) override; + bool Pull(Frame &, Context &) override; void Reset() override; private: diff --git a/src/query/plan/rule_based_planner.cpp b/src/query/plan/rule_based_planner.cpp index 599f06d03..e410064cf 100644 --- a/src/query/plan/rule_based_planner.cpp +++ b/src/query/plan/rule_based_planner.cpp @@ -116,6 +116,7 @@ class UsedSymbolsCollector : public HierarchicalTreeVisitor { } bool Visit(PrimitiveLiteral &) override { return true; } + bool Visit(ParameterLookup &) override { return true; } bool Visit(query::CreateIndex &) override { return true; } std::unordered_set symbols_; @@ -376,6 +377,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor { return true; } + bool Visit(ParameterLookup &) override { return true; } bool Visit(query::CreateIndex &) override { return true; } // Creates NamedExpression with an Identifier for each user declared symbol. diff --git a/tests/manual/query_planner.cpp b/tests/manual/query_planner.cpp index ffac8023e..a5428f4a2 100644 --- a/tests/manual/query_planner.cpp +++ b/tests/manual/query_planner.cpp @@ -588,8 +588,7 @@ void ExaminePlans( } query::AstTreeStorage MakeAst(const std::string &query, GraphDbAccessor &dba) { - query::Config config; - query::Context ctx(config, dba); + query::Context ctx(dba); // query -> AST auto parser = std::make_unique(query); // AST -> high level tree diff --git a/tests/unit/cypher_main_visitor.cpp b/tests/unit/cypher_main_visitor.cpp index 92f853c8b..9c5d5b7de 100644 --- a/tests/unit/cypher_main_visitor.cpp +++ b/tests/unit/cypher_main_visitor.cpp @@ -31,7 +31,7 @@ class Base { Base(const std::string &query) : query_string_(query) {} Dbms dbms_; std::unique_ptr db_accessor_ = dbms_.active(); - Context context_{Config{}, *db_accessor_}; + Context context_{*db_accessor_}; std::string query_string_; auto Prop(const std::string &prop_name) { diff --git a/tests/unit/query_expression_evaluator.cpp b/tests/unit/query_expression_evaluator.cpp index aadc2c0fb..44afc4fde 100644 --- a/tests/unit/query_expression_evaluator.cpp +++ b/tests/unit/query_expression_evaluator.cpp @@ -31,7 +31,8 @@ struct NoContextExpressionEvaluator { SymbolTable symbol_table; Dbms dbms; std::unique_ptr dba = dbms.active(); - ExpressionEvaluator eval{frame, symbol_table, *dba}; + Parameters parameters; + ExpressionEvaluator eval{frame, parameters, symbol_table, *dba}; }; TypedValue EvaluateFunction(const std::string &function_name, @@ -708,7 +709,8 @@ TEST(ExpressionEvaluator, Aggregation) { frame[aggr_sym] = TypedValue(1); Dbms dbms; auto dba = dbms.active(); - ExpressionEvaluator eval{frame, symbol_table, *dba}; + Parameters parameters; + ExpressionEvaluator eval{frame, parameters, symbol_table, *dba}; auto value = aggr->Accept(eval); EXPECT_EQ(value.Value(), 1); } @@ -1181,4 +1183,15 @@ TEST(ExpressionEvaluator, FunctionAssert) { ASSERT_TRUE(EvaluateFunction("ASSERT", {true}).ValueBool()); ASSERT_TRUE(EvaluateFunction("ASSERT", {true, "message"}).ValueBool()); } + +TEST(ExpressionEvaluator, ParameterLookup) { + NoContextExpressionEvaluator eval; + eval.parameters.Add(0, 42); + AstTreeStorage storage; + auto *param_lookup = storage.Create(0); + auto value = param_lookup->Accept(eval.eval); + ASSERT_EQ(value.type(), TypedValue::Type::Int); + EXPECT_EQ(value.Value(), 42); } + +} // namespace diff --git a/tests/unit/query_plan_common.hpp b/tests/unit/query_plan_common.hpp index ccea813c2..e2f4d07b7 100644 --- a/tests/unit/query_plan_common.hpp +++ b/tests/unit/query_plan_common.hpp @@ -11,6 +11,7 @@ #include "communication/result_stream_faker.hpp" #include "query/common.hpp" +#include "query/context.hpp" #include "query/frontend/semantic/symbol_table.hpp" #include "query/interpret/frame.hpp" #include "query/plan/operator.hpp" @@ -50,9 +51,11 @@ std::vector> CollectProduce( for (auto named_expression : produce->named_expressions()) symbols.emplace_back(symbol_table[*named_expression]); + Context context(db_accessor); + context.symbol_table_ = symbol_table; // stream out results auto cursor = produce->MakeCursor(db_accessor); - while (cursor->Pull(frame, symbol_table)) { + while (cursor->Pull(frame, context)) { std::vector values; for (auto &symbol : symbols) values.emplace_back(frame[symbol]); stream.Result(values); @@ -68,7 +71,9 @@ int PullAll(std::shared_ptr logical_op, GraphDbAccessor &db, Frame frame(symbol_table.max_position()); auto cursor = logical_op->MakeCursor(db); int count = 0; - while (cursor->Pull(frame, symbol_table)) count++; + Context context(db); + context.symbol_table_ = symbol_table; + while (cursor->Pull(frame, context)) count++; return count; } diff --git a/tests/unit/query_plan_create_set_remove_delete.cpp b/tests/unit/query_plan_create_set_remove_delete.cpp index 68eb55541..5a4d7d11d 100644 --- a/tests/unit/query_plan_create_set_remove_delete.cpp +++ b/tests/unit/query_plan_create_set_remove_delete.cpp @@ -290,7 +290,9 @@ TEST(QueryPlan, Delete) { auto delete_op = std::make_shared( n.op_, std::vector{n_get}, true); Frame frame(symbol_table.max_position()); - delete_op->MakeCursor(*dba)->Pull(frame, symbol_table); + Context context(*dba); + context.symbol_table_ = symbol_table; + delete_op->MakeCursor(*dba)->Pull(frame, context); dba->AdvanceCommand(); EXPECT_EQ(3, CountIterable(dba->Vertices(false))); EXPECT_EQ(3, CountIterable(dba->Edges(false))); diff --git a/tests/unit/query_plan_match_filter_return.cpp b/tests/unit/query_plan_match_filter_return.cpp index 7e0517bd4..aa187f89c 100644 --- a/tests/unit/query_plan_match_filter_return.cpp +++ b/tests/unit/query_plan_match_filter_return.cpp @@ -394,7 +394,9 @@ class QueryPlanExpandVariable : public testing::Test { map_int count_per_length; Frame frame(symbol_table.max_position()); auto cursor = input_op->MakeCursor(*dba); - while (cursor->Pull(frame, symbol_table)) { + Context context(*dba); + context.symbol_table_ = symbol_table; + while (cursor->Pull(frame, context)) { auto length = frame[symbol].Value>().size(); auto found = count_per_length.find(length); if (found == count_per_length.end()) @@ -703,7 +705,9 @@ class QueryPlanExpandBreadthFirst : public testing::Test { Frame frame(symbol_table.max_position()); auto cursor = last_op->MakeCursor(*dba); std::vector, VertexAccessor>> results; - while (cursor->Pull(frame, symbol_table)) { + Context context(*dba); + context.symbol_table_ = symbol_table; + while (cursor->Pull(frame, context)) { results.emplace_back(std::vector(), frame[node_sym].Value()); for (const TypedValue &edge : frame[edge_list_sym].ValueList())