From 541c3f0af7d67fed0159171ad6ed66be44ce4ea4 Mon Sep 17 00:00:00 2001 From: florijan Date: Tue, 18 Apr 2017 15:19:42 +0200 Subject: [PATCH] Query::Plan - Skip and Limit added Reviewers: mislav.bradac, teon.banek Reviewed By: teon.banek Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D292 --- src/query/frontend/ast/ast.hpp | 8 ++ src/query/plan/operator.cpp | 87 +++++++++++++++++++ src/query/plan/operator.hpp | 81 +++++++++++++++++- tests/unit/query_plan_bag_semantics.cpp | 106 ++++++++++++++++++++++++ tests/unit/query_plan_common.hpp | 2 +- 5 files changed, 282 insertions(+), 2 deletions(-) create mode 100644 tests/unit/query_plan_bag_semantics.cpp diff --git a/src/query/frontend/ast/ast.hpp b/src/query/frontend/ast/ast.hpp index cc8772e2c..e1717618a 100644 --- a/src/query/frontend/ast/ast.hpp +++ b/src/query/frontend/ast/ast.hpp @@ -658,10 +658,14 @@ class Return : public Clause { for (auto &expr : named_expressions_) { expr->Accept(visitor); } + if (skip_) skip_->Accept(visitor); + if (limit_) limit_->Accept(visitor); visitor.PostVisit(*this); } } std::vector named_expressions_; + Expression *skip_ = nullptr; + Expression *limit_ = nullptr; protected: Return(int uid) : Clause(uid) {} @@ -678,6 +682,8 @@ class With : public Clause { expr->Accept(visitor); } if (where_) where_->Accept(visitor); + if (skip_) skip_->Accept(visitor); + if (limit_) limit_->Accept(visitor); visitor.PostVisit(*this); } } @@ -685,6 +691,8 @@ class With : public Clause { bool distinct_ = false; std::vector named_expressions_; Where *where_ = nullptr; + Expression *skip_ = nullptr; + Expression *limit_ = nullptr; protected: With(int uid) : Clause(uid) {} diff --git a/src/query/plan/operator.cpp b/src/query/plan/operator.cpp index 62399738e..a77a4e8e5 100644 --- a/src/query/plan/operator.cpp +++ b/src/query/plan/operator.cpp @@ -1198,5 +1198,92 @@ bool Aggregate::AggregateCursor::TypedValueListEqual::operator()( TypedValue::BoolEqual{}); } +Skip::Skip(const std::shared_ptr &input, + Expression *expression) + : input_(input), expression_(expression) {} + +void Skip::Accept(LogicalOperatorVisitor &visitor) { + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + input_->Accept(visitor); + visitor.PostVisit(*this); + } +} + +std::unique_ptr Skip::MakeCursor(GraphDbAccessor &db) { + return std::make_unique(*this, db); +} + +Skip::SkipCursor::SkipCursor(Skip &self, GraphDbAccessor &db) + : self_(self), input_cursor_(self_.input_->MakeCursor(db)) {} + +bool Skip::SkipCursor::Pull(Frame &frame, const SymbolTable &symbol_table) { + while (input_cursor_->Pull(frame, symbol_table)) { + if (to_skip_ == -1) { + // first successful pull from the input + // evaluate the skip expression + ExpressionEvaluator evaluator(frame, symbol_table); + self_.expression_->Accept(evaluator); + TypedValue to_skip = evaluator.PopBack(); + if (to_skip.type() != TypedValue::Type::Int) + throw QueryRuntimeException("Result of SKIP expression must be an int"); + + to_skip_ = to_skip.Value(); + if (to_skip_ < 0) + throw QueryRuntimeException( + "Result of SKIP expression must be greater or equal to zero"); + } + + if (skipped_++ < to_skip_) continue; + return true; + } + return false; +} + +Limit::Limit(const std::shared_ptr &input, + Expression *expression) + : input_(input), expression_(expression) {} + +void Limit::Accept(LogicalOperatorVisitor &visitor) { + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); + input_->Accept(visitor); + visitor.PostVisit(*this); + } +} + +std::unique_ptr Limit::MakeCursor(GraphDbAccessor &db) { + return std::make_unique(*this, db); +} + +Limit::LimitCursor::LimitCursor(Limit &self, GraphDbAccessor &db) + : self_(self), input_cursor_(self_.input_->MakeCursor(db)) {} + +bool Limit::LimitCursor::Pull(Frame &frame, const SymbolTable &symbol_table) { + + // we need to evaluate the limit expression before the first input Pull + // because it might be 0 and thereby we shouldn't Pull from input at all + // we can do this before Pulling from the input because the limit expression + // is not allowed to contain any identifiers + if (limit_ == -1) { + ExpressionEvaluator evaluator(frame, symbol_table); + self_.expression_->Accept(evaluator); + TypedValue limit = evaluator.PopBack(); + if (limit.type() != TypedValue::Type::Int) + throw QueryRuntimeException("Result of LIMIT expression must be an int"); + + limit_ = limit.Value(); + if (limit_ < 0) + throw QueryRuntimeException( + "Result of LIMIT expression must be greater or equal to zero"); + } + + // check we have not exceeded the limit before pulling + if (pulled_++ >= limit_) + return false; + + return input_cursor_->Pull(frame, symbol_table); +} + } // namespace plan } // namespace query diff --git a/src/query/plan/operator.hpp b/src/query/plan/operator.hpp index ba7af213a..61ac6de94 100644 --- a/src/query/plan/operator.hpp +++ b/src/query/plan/operator.hpp @@ -61,6 +61,8 @@ class ExpandUniquenessFilter; class Accumulate; class AdvanceCommand; class Aggregate; +class Skip; +class Limit; /** @brief Base class for visitors of @c LogicalOperator class hierarchy. */ using LogicalOperatorVisitor = @@ -69,7 +71,7 @@ using LogicalOperatorVisitor = SetProperties, SetLabels, RemoveProperty, RemoveLabels, ExpandUniquenessFilter, ExpandUniquenessFilter, Accumulate, - AdvanceCommand, Aggregate>; + AdvanceCommand, Aggregate, Skip, Limit>; /** @brief Base class for logical operators. * @@ -929,5 +931,82 @@ class Aggregate : public LogicalOperator { }; }; +/** @brief Skips a number of Pulls from the input op. + * + * The given expression determines how many Pulls from the input + * should be skipped (ignored). + * All other successful Pulls from the + * input are simply passed through. + * + * The given expression is evaluated after the first Pull from + * the input, and only once. Neo does not allow this expression + * to contain identifiers, and neither does Memgraph, but this + * operator's implementation does not expect this. + */ +class Skip : public LogicalOperator { + public: + Skip(const std::shared_ptr &input, Expression *expression); + void Accept(LogicalOperatorVisitor &visitor) override; + std::unique_ptr MakeCursor(GraphDbAccessor &db) override; + + private: + const std::shared_ptr input_; + Expression *expression_; + + class SkipCursor : public Cursor { + public: + SkipCursor(Skip &self, GraphDbAccessor &db); + bool Pull(Frame &frame, const SymbolTable &symbol_table) override; + + private: + Skip &self_; + std::unique_ptr input_cursor_; + // init to_skip_ to -1, indicating + // that it's still unknown (input has not been Pulled yet) + int to_skip_{-1}; + int skipped_{0}; + }; +}; + +/** @brief Limits the number of Pulls from the input op. + * + * The given expression determines how many + * input Pulls should be passed through. The input is not + * Pulled once this limit is reached. Note that this has + * implications: the out-of-bounds input Pulls are never + * evaluated. + * + * The limit expression must NOT use anything from the + * Frame. It is evaluated before the first Pull from the + * input. This is consistent with Neo (they don't allow + * identifiers in limit expressions), and it's necessary + * when limit evaluates to 0 (because 0 Pulls from the + * input should be performed). + */ +class Limit : public LogicalOperator { + public: + Limit(const std::shared_ptr &input, Expression *expression); + void Accept(LogicalOperatorVisitor &visitor) override; + std::unique_ptr MakeCursor(GraphDbAccessor &db) override; + + private: + const std::shared_ptr input_; + Expression *expression_; + + class LimitCursor : public Cursor { + public: + LimitCursor(Limit &self, GraphDbAccessor &db); + bool Pull(Frame &frame, const SymbolTable &symbol_table) override; + + private: + Limit &self_; + std::unique_ptr input_cursor_; + // init limit_ to -1, indicating + // that it's still unknown (Cursor has not been Pulled yet) + int limit_{-1}; + int pulled_{0}; + }; +}; + } // namespace plan } // namespace query diff --git a/tests/unit/query_plan_bag_semantics.cpp b/tests/unit/query_plan_bag_semantics.cpp new file mode 100644 index 000000000..721e8ffda --- /dev/null +++ b/tests/unit/query_plan_bag_semantics.cpp @@ -0,0 +1,106 @@ +// +// Copyright 2017 Memgraph +// Created by Florijan Stamenkovic on 14.03.17. +// + +#include +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +#include "communication/result_stream_faker.hpp" +#include "dbms/dbms.hpp" +#include "query/context.hpp" +#include "query/exceptions.hpp" +#include "query/plan/operator.hpp" + +#include "query_plan_common.hpp" + +using namespace query; +using namespace query::plan; + +TEST(QueryPlan, Skip) { + Dbms dbms; + auto dba = dbms.active(); + + AstTreeStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n1"); + auto skip = std::make_shared(n.op_, LITERAL(2)); + + EXPECT_EQ(0, PullAll(skip, *dba, symbol_table)); + + dba->insert_vertex(); + dba->advance_command(); + EXPECT_EQ(0, PullAll(skip, *dba, symbol_table)); + + dba->insert_vertex(); + dba->advance_command(); + EXPECT_EQ(0, PullAll(skip, *dba, symbol_table)); + + dba->insert_vertex(); + dba->advance_command(); + EXPECT_EQ(1, PullAll(skip, *dba, symbol_table)); + + for (int i = 0; i < 10; ++i) + dba->insert_vertex(); + dba->advance_command(); + EXPECT_EQ(11, PullAll(skip, *dba, symbol_table)); +} + +TEST(QueryPlan, Limit) { + Dbms dbms; + auto dba = dbms.active(); + + AstTreeStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n1"); + auto skip = std::make_shared(n.op_, LITERAL(2)); + + EXPECT_EQ(0, PullAll(skip, *dba, symbol_table)); + + dba->insert_vertex(); + dba->advance_command(); + EXPECT_EQ(1, PullAll(skip, *dba, symbol_table)); + + dba->insert_vertex(); + dba->advance_command(); + EXPECT_EQ(2, PullAll(skip, *dba, symbol_table)); + + dba->insert_vertex(); + dba->advance_command(); + EXPECT_EQ(2, PullAll(skip, *dba, symbol_table)); + + for (int i = 0; i < 10; ++i) + dba->insert_vertex(); + dba->advance_command(); + EXPECT_EQ(2, PullAll(skip, *dba, symbol_table)); +} + +TEST(QueryPlan, CreateLimit) { + // CREATE (n), (m) + // MATCH (n) CREATE (m) LIMIT 1 + // in the end we need to have 3 vertices in the db + Dbms dbms; + auto dba = dbms.active(); + dba->insert_vertex(); + dba->insert_vertex(); + dba->advance_command(); + + AstTreeStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n1"); + auto m = NODE("m"); + symbol_table[*m->identifier_] = symbol_table.CreateSymbol("m"); + auto c = std::make_shared(m, n.op_); + auto skip = std::make_shared(c, LITERAL(1)); + + EXPECT_EQ(1, PullAll(skip, *dba, symbol_table)); + dba->advance_command(); + EXPECT_EQ(3, CountIterable(dba->vertices())); +} diff --git a/tests/unit/query_plan_common.hpp b/tests/unit/query_plan_common.hpp index f1b5e2489..696111ef4 100644 --- a/tests/unit/query_plan_common.hpp +++ b/tests/unit/query_plan_common.hpp @@ -57,7 +57,7 @@ auto CollectProduce(std::shared_ptr produce, SymbolTable &symbol_table, } int PullAll(std::shared_ptr logical_op, GraphDbAccessor &db, - SymbolTable symbol_table) { + SymbolTable &symbol_table) { Frame frame(symbol_table.max_position()); auto cursor = logical_op->MakeCursor(db); int count = 0;