diff --git a/CHANGELOG.md b/CHANGELOG.md index 4f0ea4270..81b29bd0f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ ### Major Features and Improvements +* Support for `all` function in openCypher. * User specified transaction execution timeout. * Support for query parameters (except for parameters in place of property maps). diff --git a/docs/user_technical/open-cypher.md b/docs/user_technical/open-cypher.md index 806638dac..c39cb5e5f 100644 --- a/docs/user_technical/open-cypher.md +++ b/docs/user_technical/open-cypher.md @@ -421,6 +421,7 @@ functions. `startsWith` | Check if the first argument starts with the second. `endsWith` | Check if the first argument ends with the second. `contains` | Check if the first argument has an element which is equal to the second argument. + `all` | Check if all elements of a list satisfy a predicate.<br/>The syntax is: `all(variable IN list WHERE predicate)`. #### Parameters diff --git a/src/query/frontend/ast/ast.hpp b/src/query/frontend/ast/ast.hpp index 8f3ac8c9b..097864d7a 100644 --- a/src/query/frontend/ast/ast.hpp +++ b/src/query/frontend/ast/ast.hpp @@ -1014,6 +1014,42 @@ class Where : public Tree { Where(int uid, Expression *expression) : Tree(uid), expression_(expression) {} }; +class All : public Expression { + friend class AstTreeStorage; + + public: + DEFVISITABLE(TreeVisitor<TypedValue>); + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + identifier_->Accept(visitor) && list_expression_->Accept(visitor) && + where_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + + All *Clone(AstTreeStorage &storage) const override { + return storage.Create<All>(identifier_->Clone(storage), + list_expression_->Clone(storage), + where_->Clone(storage)); + } + + Identifier *identifier_ = nullptr; + Expression *list_expression_ = nullptr; + Where *where_ = nullptr; + + protected: + All(int uid, Identifier *identifier, Expression *list_expression, + Where *where) + : Expression(uid), + identifier_(identifier), + list_expression_(list_expression), + where_(where) { + debug_assert(identifier, "identifier must not be nullptr"); + debug_assert(list_expression, "list_expression must not be nullptr"); + debug_assert(where, "where must not be nullptr"); + } +}; + class Match : public Clause { friend class AstTreeStorage; diff --git a/src/query/frontend/ast/ast_visitor.hpp b/src/query/frontend/ast/ast_visitor.hpp index e4d985530..a73db0202 100644 --- a/src/query/frontend/ast/ast_visitor.hpp +++ b/src/query/frontend/ast/ast_visitor.hpp @@ -13,6 +13,7 @@ class LabelsTest; class EdgeTypeTest; class Aggregation; class Function; +class All; class Create; class Match; class Return; @@ -63,9 +64,9 @@ using TreeCompositeVisitor = ::utils::CompositeVisitor< GreaterEqualOperator, InListOperator, ListIndexingOperator, ListSlicingOperator, UnaryPlusOperator, UnaryMinusOperator, IsNullOperator, ListLiteral, PropertyLookup, LabelsTest, EdgeTypeTest, Aggregation, - Function, Create, Match, Return, With, Pattern, NodeAtom, EdgeAtom, Delete, - Where, SetProperty, SetProperties, SetLabels, RemoveProperty, RemoveLabels, - Merge, Unwind>; + Function, All, Create, Match, Return, With, Pattern, NodeAtom, EdgeAtom, + Delete, Where, SetProperty, SetProperties, SetLabels, RemoveProperty, + RemoveLabels, Merge, Unwind>; using TreeLeafVisitor = ::utils::LeafVisitor<Identifier, PrimitiveLiteral, CreateIndex>; @@ -88,8 +89,8 @@ using TreeVisitor = ::utils::Visitor< GreaterEqualOperator, InListOperator, ListIndexingOperator, ListSlicingOperator, UnaryPlusOperator, UnaryMinusOperator, IsNullOperator, ListLiteral, PropertyLookup, LabelsTest, EdgeTypeTest, Aggregation, - Function, Create, Match, Return, With, Pattern, NodeAtom, EdgeAtom, Delete, - Where, SetProperty, SetProperties, SetLabels, RemoveProperty, RemoveLabels, - Merge, Unwind, Identifier, PrimitiveLiteral, CreateIndex>; + Function, All, Create, Match, Return, With, Pattern, NodeAtom, EdgeAtom, + Delete, Where, SetProperty, SetProperties, SetLabels, RemoveProperty, + RemoveLabels, Merge, Unwind, Identifier, PrimitiveLiteral, CreateIndex>; } // namespace query diff --git a/src/query/frontend/ast/cypher_main_visitor.cpp b/src/query/frontend/ast/cypher_main_visitor.cpp index 0937a1b79..5dff49bbd 100644 --- a/src/query/frontend/ast/cypher_main_visitor.cpp +++ b/src/query/frontend/ast/cypher_main_visitor.cpp @@ -756,6 +756,17 @@ antlrcpp::Any CypherMainVisitor::visitAtom(CypherParser::AtomContext *ctx) { // functionInvocation and atom producions in opencypher grammar. return static_cast<Expression *>( storage_.Create<Aggregation>(nullptr, Aggregation::Op::COUNT)); + } else if (ctx->ALL()) { + auto *ident = storage_.Create<Identifier>(ctx->filterExpression() + ->idInColl() + ->variable() + ->accept(this) + .as<std::string>()); + Expression *list_expr = + ctx->filterExpression()->idInColl()->expression()->accept(this); + Where *where = ctx->filterExpression()->where()->accept(this); + return static_cast<Expression *>( + storage_.Create<All>(ident, list_expr, where)); } // TODO: Implement this. We don't support comprehensions, filtering... at // the moment. @@ -1024,4 +1035,10 @@ antlrcpp::Any CypherMainVisitor::visitUnwind(CypherParser::UnwindContext *ctx) { return storage_.Create<Unwind>(named_expr); } +antlrcpp::Any CypherMainVisitor::visitFilterExpression( + CypherParser::FilterExpressionContext *) { + debug_fail("Should never be called. See documentation in hpp."); + return 0; +} + } // namespace query::frontend diff --git a/src/query/frontend/ast/cypher_main_visitor.hpp b/src/query/frontend/ast/cypher_main_visitor.hpp index 855a4d54a..37944aecf 100644 --- a/src/query/frontend/ast/cypher_main_visitor.hpp +++ b/src/query/frontend/ast/cypher_main_visitor.hpp @@ -530,6 +530,13 @@ class CypherMainVisitor : public antlropencypher::CypherBaseVisitor { */ antlrcpp::Any visitUnwind(CypherParser::UnwindContext *ctx) override; + /** + * Never call this. Ast generation for these expressions should be done by + * explicitly visiting the members of @c FilterExpressionContext. + */ + antlrcpp::Any visitFilterExpression( + CypherParser::FilterExpressionContext *) override; + public: Query *query() { return query_; } AstTreeStorage &storage() { return storage_; } diff --git a/src/query/frontend/semantic/symbol_generator.cpp b/src/query/frontend/semantic/symbol_generator.cpp index 0919feb77..5d49fe2c2 100644 --- a/src/query/frontend/semantic/symbol_generator.cpp +++ b/src/query/frontend/semantic/symbol_generator.cpp @@ -4,8 +4,11 @@ #include "query/frontend/semantic/symbol_generator.hpp" +#include <experimental/optional> #include <unordered_set> +#include "utils/algorithm.hpp" + namespace query { auto SymbolGenerator::CreateSymbol(const std::string &name, bool user_declared, @@ -265,6 +268,28 @@ bool SymbolGenerator::PostVisit(Aggregation &) { return true; } +bool SymbolGenerator::PreVisit(All &all) { + all.list_expression_->Accept(*this); + // Bind the new symbol after visiting the list expression. Keep the old symbol + // so it can be restored. + std::experimental::optional<Symbol> prev_symbol; + auto prev_symbol_it = scope_.symbols.find(all.identifier_->name_); + if (prev_symbol_it != scope_.symbols.end()) { + prev_symbol = prev_symbol_it->second; + } + symbol_table_[*all.identifier_] = CreateSymbol(all.identifier_->name_, true); + // Visit Where with the new symbol bound. + all.where_->Accept(*this); + // Restore the old symbol or just remove the newly bound if there was no + // symbol before. + if (prev_symbol) { + scope_.symbols[all.identifier_->name_] = *prev_symbol; + } else { + scope_.symbols.erase(all.identifier_->name_); + } + return false; +} + // Pattern and its subparts. bool SymbolGenerator::PreVisit(Pattern &pattern) { diff --git a/src/query/frontend/semantic/symbol_generator.hpp b/src/query/frontend/semantic/symbol_generator.hpp index 843a6bedc..96ec23eee 100644 --- a/src/query/frontend/semantic/symbol_generator.hpp +++ b/src/query/frontend/semantic/symbol_generator.hpp @@ -44,6 +44,7 @@ class SymbolGenerator : public HierarchicalTreeVisitor { ReturnType Visit(PrimitiveLiteral &) override { return true; } bool PreVisit(Aggregation &) override; bool PostVisit(Aggregation &) override; + bool PreVisit(All &) override; // Pattern and its subparts. bool PreVisit(Pattern &) override; diff --git a/src/query/interpret/eval.hpp b/src/query/interpret/eval.hpp index b462711e0..25866932a 100644 --- a/src/query/interpret/eval.hpp +++ b/src/query/interpret/eval.hpp @@ -326,6 +326,33 @@ class ExpressionEvaluator : public TreeVisitor<TypedValue> { return function.function_(arguments, db_accessor_); } + TypedValue Visit(All &all) override { + auto list_value = all.list_expression_->Accept(*this); + if (list_value.IsNull()) { + return TypedValue::Null; + } + if (list_value.type() != TypedValue::Type::List) { + throw QueryRuntimeException("'ALL' expected a list, but got {}", + list_value.type()); + } + const auto &list = list_value.Value<std::vector<TypedValue>>(); + const auto &symbol = symbol_table_.at(*all.identifier_); + for (const auto &element : list) { + frame_[symbol] = element; + auto result = all.where_->expression_->Accept(*this); + if (!result.IsNull() && result.type() != TypedValue::Type::Bool) { + throw QueryRuntimeException( + "Predicate of 'ALL' needs to evaluate to 'Boolean', but it " + "resulted in '{}'", + result.type()); + } + if (result.IsNull() || !result.Value<bool>()) { + return result; + } + } + return true; + } + private: // If the given TypedValue contains accessors, switch them to New or Old, // depending on use_new_ flag. diff --git a/src/query/plan/rule_based_planner.cpp b/src/query/plan/rule_based_planner.cpp index 20bfeb902..fed8abf2a 100644 --- a/src/query/plan/rule_based_planner.cpp +++ b/src/query/plan/rule_based_planner.cpp @@ -133,16 +133,22 @@ class UsedSymbolsCollector : public HierarchicalTreeVisitor { using HierarchicalTreeVisitor::PreVisit; using HierarchicalTreeVisitor::PostVisit; - using typename HierarchicalTreeVisitor::ReturnType; using HierarchicalTreeVisitor::Visit; - ReturnType Visit(Identifier &ident) override { + bool PostVisit(All &all) override { + // Remove the symbol which is bound by all, because we are only interested + // in free (unbound) symbols. + symbols_.erase(symbol_table_.at(*all.identifier_)); + return true; + } + + bool Visit(Identifier &ident) override { symbols_.insert(symbol_table_.at(ident)); return true; } - ReturnType Visit(PrimitiveLiteral &) override { return true; } - ReturnType Visit(query::CreateIndex &) override { return true; } + bool Visit(PrimitiveLiteral &) override { return true; } + bool Visit(query::CreateIndex &) override { return true; } std::unordered_set<Symbol> symbols_; const SymbolTable &symbol_table_; @@ -280,6 +286,21 @@ class ReturnBodyContext : public HierarchicalTreeVisitor { return true; } + bool PostVisit(All &all) override { + // Remove the symbol which is bound by all, because we are only interested + // in free (unbound) symbols. + used_symbols_.erase(symbol_table_.at(*all.identifier_)); + debug_assert(has_aggregation_.size() >= 3U, + "Expected 3 has_aggregation_ flags for ALL arguments"); + bool has_aggr = false; + for (int i = 0; i < 3; ++i) { + has_aggr = has_aggr || has_aggregation_.back(); + has_aggregation_.pop_back(); + } + has_aggregation_.emplace_back(has_aggr); + return true; + } + bool Visit(Identifier &ident) override { const auto &symbol = symbol_table_.at(ident); if (std::find(output_symbols_.begin(), output_symbols_.end(), symbol) == diff --git a/tests/qa/tck_engine/tests/memgraph_V1/features/functions.feature b/tests/qa/tck_engine/tests/memgraph_V1/features/functions.feature index 0c2112a91..a71cb7829 100644 --- a/tests/qa/tck_engine/tests/memgraph_V1/features/functions.feature +++ b/tests/qa/tck_engine/tests/memgraph_V1/features/functions.feature @@ -596,16 +596,6 @@ Feature: Functions | (:y) | false | true | | (:y) | false | true | - # Scenario: Keys test: - # Given an empty graph - # When executing query: - # """ - # CREATE (n{x: 123, a: null, b: 'x', d: 1}) RETURN KEYS(n) AS a - # """ - # Then the result should be (ignoring element order for lists) - # | a | - # | ['x', 'null', 'b'] | - # Scenario: Pi test: When executing query: """ @@ -614,3 +604,28 @@ Feature: Functions Then the result should be: | n | | 3.141592653589793 | + + Scenario: All test 01: + When executing query: + """ + RETURN all(x IN [1, 2, '3'] WHERE x < 2) AS a + """ + Then the result should be: + | a | + | false | + + Scenario: All test 02: + When executing query: + """ + RETURN all(x IN [1, 2, 3] WHERE x < 4) AS a + """ + Then the result should be: + | a | + | true | + + Scenario: All test 03: + When executing query: + """ + RETURN all(x IN [1, 2, '3'] WHERE x < 3) AS a + """ + Then an error should be raised diff --git a/tests/unit/cypher_main_visitor.cpp b/tests/unit/cypher_main_visitor.cpp index 201b7f05e..61e5d5c78 100644 --- a/tests/unit/cypher_main_visitor.cpp +++ b/tests/unit/cypher_main_visitor.cpp @@ -1383,4 +1383,21 @@ TYPED_TEST(CypherMainVisitorTest, CreateIndex) { ASSERT_EQ(create_index->property_, ast_generator.db_accessor_->property("slavko")); } + +TYPED_TEST(CypherMainVisitorTest, ReturnAll) { + TypeParam ast_generator("RETURN all(x IN [1,2,3] WHERE x = 2)"); + auto *query = ast_generator.query_; + ASSERT_EQ(query->clauses_.size(), 1U); + auto *ret = dynamic_cast<Return *>(query->clauses_[0]); + ASSERT_TRUE(ret); + ASSERT_EQ(ret->body_.named_expressions.size(), 1U); + auto *all = dynamic_cast<All *>(ret->body_.named_expressions[0]->expression_); + ASSERT_TRUE(all); + EXPECT_EQ(all->identifier_->name_, "x"); + auto *list_literal = dynamic_cast<ListLiteral *>(all->list_expression_); + EXPECT_TRUE(list_literal); + auto *eq = dynamic_cast<EqualOperator *>(all->where_->expression_); + EXPECT_TRUE(eq); +} + } diff --git a/tests/unit/query_common.hpp b/tests/unit/query_common.hpp index 958bd8254..9a3657b19 100644 --- a/tests/unit/query_common.hpp +++ b/tests/unit/query_common.hpp @@ -483,3 +483,7 @@ auto GetMerge(AstTreeStorage &storage, Pattern *pattern, OnMatch on_match, // List slicing #define SLICE(list, lower_bound, upper_bound) \ storage.Create<query::ListSlicingOperator>(list, lower_bound, upper_bound) +// all(variable IN list WHERE predicate) +#define ALL(variable, list, where) \ + storage.Create<query::All>(storage.Create<query::Identifier>(variable), \ + list, where) diff --git a/tests/unit/query_expression_evaluator.cpp b/tests/unit/query_expression_evaluator.cpp index 35136edbb..ac80bcb65 100644 --- a/tests/unit/query_expression_evaluator.cpp +++ b/tests/unit/query_expression_evaluator.cpp @@ -6,14 +6,15 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "database/dbms.hpp" #include "database/graph_db_accessor.hpp" #include "database/graph_db_datatypes.hpp" -#include "database/dbms.hpp" #include "query/frontend/ast/ast.hpp" #include "query/frontend/opencypher/parser.hpp" #include "query/interpret/awesome_memgraph_functions.hpp" #include "query/interpret/eval.hpp" #include "query/interpret/frame.hpp" + #include "query_common.hpp" using namespace query; @@ -1007,4 +1008,38 @@ TEST(ExpressionEvaluator, FunctionContains) { EXPECT_FALSE(EvaluateFunction(kContains, {"cde", "abcdef"}).Value<bool>()); EXPECT_FALSE(EvaluateFunction(kContains, {"abcdef", "dEf"}).Value<bool>()); } + +TEST(ExpressionEvaluator, FunctionAll) { + AstTreeStorage storage; + auto *ident_x = IDENT("x"); + auto *all = + ALL("x", LIST(LITERAL(1), LITERAL(2)), WHERE(EQ(ident_x, LITERAL(1)))); + NoContextExpressionEvaluator eval; + const auto x_sym = eval.symbol_table.CreateSymbol("x", true); + eval.symbol_table[*all->identifier_] = x_sym; + eval.symbol_table[*ident_x] = x_sym; + auto value = all->Accept(eval.eval); + ASSERT_EQ(value.type(), TypedValue::Type::Bool); + EXPECT_FALSE(value.Value<bool>()); +} + +TEST(ExpressionEvaluator, FunctionAllNullList) { + AstTreeStorage storage; + auto *all = ALL("x", LITERAL(TypedValue::Null), WHERE(LITERAL(true))); + NoContextExpressionEvaluator eval; + const auto x_sym = eval.symbol_table.CreateSymbol("x", true); + eval.symbol_table[*all->identifier_] = x_sym; + auto value = all->Accept(eval.eval); + EXPECT_TRUE(value.IsNull()); +} + +TEST(ExpressionEvaluator, FunctionAllWhereWrongType) { + AstTreeStorage storage; + auto *all = ALL("x", LIST(LITERAL(1)), WHERE(LITERAL(2))); + NoContextExpressionEvaluator eval; + const auto x_sym = eval.symbol_table.CreateSymbol("x", true); + eval.symbol_table[*all->identifier_] = x_sym; + EXPECT_THROW(all->Accept(eval.eval), QueryRuntimeException); +} + } diff --git a/tests/unit/query_planner.cpp b/tests/unit/query_planner.cpp index a7001adf5..1046c1a24 100644 --- a/tests/unit/query_planner.cpp +++ b/tests/unit/query_planner.cpp @@ -1205,4 +1205,14 @@ TEST(TestLogicalPlanner, UnableToUsePropertyIndex) { ExpectProduce()); } +TEST(TestLogicalPlanner, ReturnSumGroupByAll) { + // Test RETURN sum([1,2,3]), all(x in [1] where x = 1) + AstTreeStorage storage; + auto sum = SUM(LIST(LITERAL(1), LITERAL(2), LITERAL(3))); + auto *all = ALL("x", LIST(LITERAL(1)), WHERE(EQ(IDENT("x"), LITERAL(1)))); + QUERY(RETURN(sum, AS("sum"), all, AS("all"))); + auto aggr = ExpectAggregate({sum}, {all}); + CheckPlan(storage, aggr, ExpectProduce()); +} + } // namespace diff --git a/tests/unit/query_semantic.cpp b/tests/unit/query_semantic.cpp index 00e1ae481..06e506220 100644 --- a/tests/unit/query_semantic.cpp +++ b/tests/unit/query_semantic.cpp @@ -984,5 +984,30 @@ TEST(TestSymbolGenerator, MatchPropertySameIdentifier) { EXPECT_EQ(n, symbol_table.at(*n_prop->expression_)); } +TEST(TestSymbolGenerator, WithReturnAll) { + // Test WITH 42 AS x RETURN all(x IN [x] WHERE x = 2) AS x, x AS y + AstTreeStorage storage; + auto *with_as_x = AS("x"); + auto *list_x = IDENT("x"); + auto *where_x = IDENT("x"); + auto *all = ALL("x", LIST(list_x), WHERE(EQ(where_x, LITERAL(2)))); + auto *ret_as_x = AS("x"); + auto *ret_x = IDENT("x"); + auto query = QUERY(WITH(LITERAL(42), with_as_x), + RETURN(all, ret_as_x, ret_x, AS("y"))); + SymbolTable symbol_table; + SymbolGenerator symbol_generator(symbol_table); + query->Accept(symbol_generator); + // Symbols for `WITH .. AS x`, `ALL(x ...)`, `ALL(...) AS x` and `AS y`. + EXPECT_EQ(symbol_table.max_position(), 4); + // Check `WITH .. AS x` is the same as `[x]` and `RETURN ... x AS y` + EXPECT_EQ(symbol_table.at(*with_as_x), symbol_table.at(*list_x)); + EXPECT_EQ(symbol_table.at(*with_as_x), symbol_table.at(*ret_x)); + EXPECT_NE(symbol_table.at(*with_as_x), symbol_table.at(*all->identifier_)); + EXPECT_NE(symbol_table.at(*with_as_x), symbol_table.at(*ret_as_x)); + // Check `ALL(x ...)` is only equal to `WHERE x = 2` + EXPECT_EQ(symbol_table.at(*all->identifier_), symbol_table.at(*where_x)); + EXPECT_NE(symbol_table.at(*all->identifier_), symbol_table.at(*ret_as_x)); +} } // namespace