diff --git a/src/query/frontend/ast/ast.hpp b/src/query/frontend/ast/ast.hpp index 4977ef5a6..05d39d282 100644 --- a/src/query/frontend/ast/ast.hpp +++ b/src/query/frontend/ast/ast.hpp @@ -67,7 +67,6 @@ class OrOperator : public BinaryOperator { void Accept(TreeVisitorBase &visitor) override { if (visitor.PreVisit(*this)) { visitor.Visit(*this); - // TODO: Should we short-circuit? expression1_->Accept(visitor); expression2_->Accept(visitor); visitor.PostVisit(*this); @@ -102,7 +101,27 @@ class AndOperator : public BinaryOperator { void Accept(TreeVisitorBase &visitor) override { if (visitor.PreVisit(*this)) { visitor.Visit(*this); - // TODO: Should we short-circuit? + expression1_->Accept(visitor); + expression2_->Accept(visitor); + visitor.PostVisit(*this); + } + } + + protected: + using BinaryOperator::BinaryOperator; +}; + +// This is separate operator so that we can implement different short-circuiting +// semantics than regular AndOperator. At this point CypherMainVisitor shouldn't +// concern itself with this, and should constructor only AndOperator-s. This is +// used in query planner at the moment. +class FilterAndOperator : public BinaryOperator { + friend class AstTreeStorage; + + public: + void Accept(TreeVisitorBase &visitor) override { + if (visitor.PreVisit(*this)) { + visitor.Visit(*this); expression1_->Accept(visitor); expression2_->Accept(visitor); visitor.PostVisit(*this); diff --git a/src/query/frontend/ast/ast_visitor.hpp b/src/query/frontend/ast/ast_visitor.hpp index 09a257d6d..d6cb96a8a 100644 --- a/src/query/frontend/ast/ast_visitor.hpp +++ b/src/query/frontend/ast/ast_visitor.hpp @@ -25,6 +25,7 @@ class ListLiteral; class OrOperator; class XorOperator; class AndOperator; +class FilterAndOperator; class NotOperator; class AdditionOperator; class SubtractionOperator; @@ -54,15 +55,15 @@ class Merge; class Unwind; using TreeVisitorBase = ::utils::Visitor< - Query, NamedExpression, OrOperator, XorOperator, AndOperator, NotOperator, - AdditionOperator, SubtractionOperator, MultiplicationOperator, - DivisionOperator, ModOperator, NotEqualOperator, EqualOperator, - LessOperator, GreaterOperator, LessEqualOperator, GreaterEqualOperator, - InListOperator, ListIndexingOperator, ListSlicingOperator, - UnaryPlusOperator, UnaryMinusOperator, IsNullOperator, Identifier, - PrimitiveLiteral, ListLiteral, PropertyLookup, LabelsTest, EdgeTypeTest, - Aggregation, Function, Create, Match, Return, With, Pattern, NodeAtom, - EdgeAtom, Delete, Where, SetProperty, SetProperties, SetLabels, + Query, NamedExpression, OrOperator, XorOperator, AndOperator, + FilterAndOperator, NotOperator, AdditionOperator, SubtractionOperator, + MultiplicationOperator, DivisionOperator, ModOperator, NotEqualOperator, + EqualOperator, LessOperator, GreaterOperator, LessEqualOperator, + GreaterEqualOperator, InListOperator, ListIndexingOperator, + ListSlicingOperator, UnaryPlusOperator, UnaryMinusOperator, IsNullOperator, + Identifier, PrimitiveLiteral, ListLiteral, PropertyLookup, LabelsTest, + EdgeTypeTest, Aggregation, Function, Create, Match, Return, With, Pattern, + NodeAtom, EdgeAtom, Delete, Where, SetProperty, SetProperties, SetLabels, RemoveProperty, RemoveLabels, Merge, Unwind>; } // namespace query diff --git a/src/query/interpret/eval.hpp b/src/query/interpret/eval.hpp index fca166976..0abdf7133 100644 --- a/src/query/interpret/eval.hpp +++ b/src/query/interpret/eval.hpp @@ -89,6 +89,20 @@ class ExpressionEvaluator : public TreeVisitorBase { #undef BINARY_OPERATOR_VISITOR #undef UNARY_OPERATOR_VISITOR + bool PreVisit(FilterAndOperator &op) override { + op.expression1_->Accept(*this); + auto expression1 = PopBack(); + if (expression1.IsNull() || !expression1.Value()) { + // If first expression is null or false, don't execute the second one. + result_stack_.emplace_back(expression1); + return false; + } + op.expression2_->Accept(*this); + auto expression2 = PopBack(); + result_stack_.emplace_back(expression2); + return false; + } + void PostVisit(InListOperator &) override { auto _list = PopBack(); auto literal = PopBack(); diff --git a/src/query/plan/planner.cpp b/src/query/plan/planner.cpp index 2a1a98216..d6f356112 100644 --- a/src/query/plan/planner.cpp +++ b/src/query/plan/planner.cpp @@ -147,7 +147,7 @@ Expression *PropertiesEqual(AstTreeStorage &storage, storage.Create(atom->identifier_, prop_pair.first); auto *prop_equal = storage.Create(property_lookup, prop_pair.second); - filter_expr = BoolJoin(storage, filter_expr, prop_equal); + filter_expr = BoolJoin(storage, filter_expr, prop_equal); } return filter_expr; } @@ -166,7 +166,7 @@ auto &CollectPatternFilters( if (labels_filter || props_filter) { collector.symbols_.insert(symbol_table.at(*node->identifier_).position_); filters.emplace_back( - BoolJoin(storage, labels_filter, props_filter), + BoolJoin(storage, labels_filter, props_filter), collector.symbols_); collector.symbols_.clear(); } @@ -183,7 +183,7 @@ auto &CollectPatternFilters( const auto &edge_symbol = symbol_table.at(*edge->identifier_); collector.symbols_.insert(edge_symbol.position_); filters->emplace_back( - BoolJoin(storage, types_filter, props_filter), + BoolJoin(storage, types_filter, props_filter), collector.symbols_); collector.symbols_.clear(); } @@ -235,7 +235,7 @@ auto GenFilters( for (auto filters_it = filters.begin(); filters_it != filters.end();) { if (HasBoundFilterSymbols(bound_symbols, *filters_it)) { filter_expr = - BoolJoin(storage, filter_expr, filters_it->first); + BoolJoin(storage, filter_expr, filters_it->first); filters_it = filters.erase(filters_it); } else { filters_it++; @@ -456,7 +456,8 @@ class ReturnBodyContext : public TreeVisitorBase { // Aggregation contains a virtual symbol, where the result will be stored. const auto &symbol = symbol_table_.at(aggr); aggregations_.emplace_back(aggr.expression_, aggr.op_, symbol); - // aggregation expression_ is opional in COUNT(*) so it's possible the has_aggregation_ stack is empty + // aggregation expression_ is opional in COUNT(*) so it's possible the + // has_aggregation_ stack is empty if (aggr.expression_) has_aggregation_.back() = true; else diff --git a/tests/unit/query_expression_evaluator.cpp b/tests/unit/query_expression_evaluator.cpp index c4e93bdb9..7cad2bae8 100644 --- a/tests/unit/query_expression_evaluator.cpp +++ b/tests/unit/query_expression_evaluator.cpp @@ -90,6 +90,32 @@ TEST(ExpressionEvaluator, AndOperator) { ASSERT_EQ(eval.eval.PopBack().Value(), false); } +TEST(ExpressionEvaluator, FilterAndOperator) { + AstTreeStorage storage; + NoContextExpressionEvaluator eval; + { + auto *op = storage.Create( + storage.Create(true), + storage.Create(true)); + op->Accept(eval.eval); + EXPECT_EQ(eval.eval.PopBack().Value(), true); + } + { + auto *op = storage.Create( + storage.Create(false), + storage.Create(5)); + op->Accept(eval.eval); + EXPECT_EQ(eval.eval.PopBack().Value(), false); + } + { + auto *op = storage.Create( + storage.Create(TypedValue::Null), + storage.Create(5)); + op->Accept(eval.eval); + EXPECT_TRUE(eval.eval.PopBack().IsNull()); + } +} + TEST(ExpressionEvaluator, AdditionOperator) { AstTreeStorage storage; NoContextExpressionEvaluator eval;