From 4326847ab3f70e946d07b77b46e2b5b967ec63c8 Mon Sep 17 00:00:00 2001 From: Teon Banek <teon.banek@memgraph.io> Date: Mon, 12 Feb 2018 16:13:45 +0100 Subject: [PATCH] Add SINGLE function to openCypher Reviewers: florijan, msantl, buda Reviewed By: msantl Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D1193 --- docs/user_technical/open-cypher.md | 1 + src/query/frontend/ast/ast.cpp | 1 + src/query/frontend/ast/ast.hpp | 63 +++++++++++++++++++ src/query/frontend/ast/ast_visitor.hpp | 11 ++-- .../frontend/ast/cypher_main_visitor.cpp | 11 ++++ .../frontend/semantic/symbol_generator.cpp | 6 ++ .../frontend/semantic/symbol_generator.hpp | 1 + src/query/interpret/eval.hpp | 34 ++++++++++ src/query/plan/preprocess.hpp | 8 ++- src/query/plan/rule_based_planner.cpp | 16 ++++- .../memgraph_V1/features/functions.feature | 25 ++++++++ tests/unit/cypher_main_visitor.cpp | 19 ++++++ tests/unit/query_common.hpp | 3 + tests/unit/query_expression_evaluator.cpp | 38 +++++++++++ tests/unit/query_semantic.cpp | 23 +++++++ 15 files changed, 253 insertions(+), 7 deletions(-) diff --git a/docs/user_technical/open-cypher.md b/docs/user_technical/open-cypher.md index 5d8c87c7a..31d112710 100644 --- a/docs/user_technical/open-cypher.md +++ b/docs/user_technical/open-cypher.md @@ -550,6 +550,7 @@ functions. `endsWith` | Check if the first argument ends with the second. `contains` | Check if the first argument has an element which is equal to the second argument. `all` | Check if all elements of a list satisfy a predicate.<br/>The syntax is: `all(variable IN list WHERE predicate)`.<br/> NOTE: Whenever possible, use Memgraph's lambda functions when [matching](#filtering-variable-length-paths) instead. + `single` | Check if only one element of a list satisfies a predicate.<br/>The syntax is: `single(variable IN list WHERE predicate)`. `reduce` | Accumulate list elements into a single result by applying an expression. The syntax is:<br/>`reduce(accumulator = initial_value, variable IN list | expression)`. `assert` | Raises an exception reported to the client if the given argument is not `true`. `counter` | Generates integers that are guaranteed to be unique on the database level, for the given counter name. diff --git a/src/query/frontend/ast/ast.cpp b/src/query/frontend/ast/ast.cpp index 144c7aabf..0359c494c 100644 --- a/src/query/frontend/ast/ast.cpp +++ b/src/query/frontend/ast/ast.cpp @@ -78,6 +78,7 @@ BOOST_CLASS_EXPORT_IMPLEMENT(query::Aggregation); BOOST_CLASS_EXPORT_IMPLEMENT(query::Function); BOOST_CLASS_EXPORT_IMPLEMENT(query::Reduce); BOOST_CLASS_EXPORT_IMPLEMENT(query::All); +BOOST_CLASS_EXPORT_IMPLEMENT(query::Single); BOOST_CLASS_EXPORT_IMPLEMENT(query::ParameterLookup); BOOST_CLASS_EXPORT_IMPLEMENT(query::Create); BOOST_CLASS_EXPORT_IMPLEMENT(query::Match); diff --git a/src/query/frontend/ast/ast.hpp b/src/query/frontend/ast/ast.hpp index 7e567bece..049edec2f 100644 --- a/src/query/frontend/ast/ast.hpp +++ b/src/query/frontend/ast/ast.hpp @@ -1523,6 +1523,67 @@ class All : public Expression { const unsigned int); }; +// TODO: This is pretty much copy pasted from All. Consider merging Reduce, All, +// Any and Single into something like a higher-order function call which takes a +// list argument and a function which is applied on list elements. +class Single : public Expression { + friend class AstTreeStorage; + + public: + DEFVISITABLE(TreeVisitor<TypedValue>); + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + identifier_->Accept(visitor) && list_expression_->Accept(visitor) && + where_->Accept(visitor); + } + return visitor.PostVisit(*this); + } + + Single *Clone(AstTreeStorage &storage) const override { + return storage.Create<Single>(identifier_->Clone(storage), + list_expression_->Clone(storage), + where_->Clone(storage)); + } + + // None of these should be nullptr after construction. + Identifier *identifier_ = nullptr; + Expression *list_expression_ = nullptr; + Where *where_ = nullptr; + + protected: + Single(int uid, Identifier *identifier, Expression *list_expression, + Where *where) + : Expression(uid), + identifier_(identifier), + list_expression_(list_expression), + where_(where) {} + + private: + friend class boost::serialization::access; + + BOOST_SERIALIZATION_SPLIT_MEMBER(); + + template <class TArchive> + void save(TArchive &ar, const unsigned int) const { + ar << boost::serialization::base_object<Expression>(*this); + SavePointer(ar, identifier_); + SavePointer(ar, list_expression_); + SavePointer(ar, where_); + } + + template <class TArchive> + void load(TArchive &ar, const unsigned int) { + ar >> boost::serialization::base_object<Expression>(*this); + LoadPointer(ar, identifier_); + LoadPointer(ar, list_expression_); + LoadPointer(ar, where_); + } + + template <class TArchive> + friend void boost::serialization::load_construct_data(TArchive &, Single *, + const unsigned int); +}; + class ParameterLookup : public Expression { friend class AstTreeStorage; @@ -2962,6 +3023,7 @@ LOAD_AND_CONSTRUCT(query::Aggregation, 0, nullptr, nullptr, LOAD_AND_CONSTRUCT(query::Reduce, 0, nullptr, nullptr, nullptr, nullptr, nullptr); LOAD_AND_CONSTRUCT(query::All, 0, nullptr, nullptr, nullptr); +LOAD_AND_CONSTRUCT(query::Single, 0, nullptr, nullptr, nullptr); LOAD_AND_CONSTRUCT(query::ParameterLookup, 0); LOAD_AND_CONSTRUCT(query::NamedExpression, 0); LOAD_AND_CONSTRUCT(query::NodeAtom, 0); @@ -3022,6 +3084,7 @@ BOOST_CLASS_EXPORT_KEY(query::Aggregation); BOOST_CLASS_EXPORT_KEY(query::Function); BOOST_CLASS_EXPORT_KEY(query::Reduce); BOOST_CLASS_EXPORT_KEY(query::All); +BOOST_CLASS_EXPORT_KEY(query::Single); BOOST_CLASS_EXPORT_KEY(query::ParameterLookup); BOOST_CLASS_EXPORT_KEY(query::Create); BOOST_CLASS_EXPORT_KEY(query::Match); diff --git a/src/query/frontend/ast/ast_visitor.hpp b/src/query/frontend/ast/ast_visitor.hpp index cd9b9657a..e0de05911 100644 --- a/src/query/frontend/ast/ast_visitor.hpp +++ b/src/query/frontend/ast/ast_visitor.hpp @@ -16,6 +16,7 @@ class Aggregation; class Function; class Reduce; class All; +class Single; class ParameterLookup; class Create; class Match; @@ -68,9 +69,9 @@ using TreeCompositeVisitor = ::utils::CompositeVisitor< GreaterEqualOperator, InListOperator, ListMapIndexingOperator, ListSlicingOperator, IfOperator, UnaryPlusOperator, UnaryMinusOperator, IsNullOperator, ListLiteral, MapLiteral, PropertyLookup, LabelsTest, - Aggregation, Function, Reduce, All, Create, Match, Return, With, Pattern, - NodeAtom, EdgeAtom, Delete, Where, SetProperty, SetProperties, SetLabels, - RemoveProperty, RemoveLabels, Merge, Unwind>; + Aggregation, Function, Reduce, All, Single, Create, Match, Return, With, + Pattern, NodeAtom, EdgeAtom, Delete, Where, SetProperty, SetProperties, + SetLabels, RemoveProperty, RemoveLabels, Merge, Unwind>; using TreeLeafVisitor = ::utils::LeafVisitor<Identifier, PrimitiveLiteral, ParameterLookup, CreateIndex>; @@ -93,8 +94,8 @@ using TreeVisitor = ::utils::Visitor< LessEqualOperator, GreaterEqualOperator, InListOperator, ListMapIndexingOperator, ListSlicingOperator, IfOperator, UnaryPlusOperator, UnaryMinusOperator, IsNullOperator, ListLiteral, MapLiteral, PropertyLookup, - LabelsTest, Aggregation, Function, Reduce, All, ParameterLookup, Create, - Match, Return, With, Pattern, NodeAtom, EdgeAtom, Delete, Where, + LabelsTest, Aggregation, Function, Reduce, All, Single, ParameterLookup, + Create, Match, Return, With, Pattern, NodeAtom, EdgeAtom, Delete, Where, SetProperty, SetProperties, SetLabels, RemoveProperty, RemoveLabels, Merge, Unwind, Identifier, PrimitiveLiteral, CreateIndex>; diff --git a/src/query/frontend/ast/cypher_main_visitor.cpp b/src/query/frontend/ast/cypher_main_visitor.cpp index aba2efdc3..24042794a 100644 --- a/src/query/frontend/ast/cypher_main_visitor.cpp +++ b/src/query/frontend/ast/cypher_main_visitor.cpp @@ -901,6 +901,17 @@ antlrcpp::Any CypherMainVisitor::visitAtom(CypherParser::AtomContext *ctx) { Where *where = ctx->filterExpression()->where()->accept(this); return static_cast<Expression *>( storage_.Create<All>(ident, list_expr, where)); + } else if (ctx->SINGLE()) { + auto *ident = storage_.Create<Identifier>(ctx->filterExpression() + ->idInColl() + ->variable() + ->accept(this) + .as<std::string>()); + Expression *list_expr = + ctx->filterExpression()->idInColl()->expression()->accept(this); + Where *where = ctx->filterExpression()->where()->accept(this); + return static_cast<Expression *>( + storage_.Create<Single>(ident, list_expr, where)); } else if (ctx->REDUCE()) { auto *accumulator = storage_.Create<Identifier>( ctx->reduceExpression()->accumulator->accept(this).as<std::string>()); diff --git a/src/query/frontend/semantic/symbol_generator.cpp b/src/query/frontend/semantic/symbol_generator.cpp index 2d5fa47fa..e1d5392df 100644 --- a/src/query/frontend/semantic/symbol_generator.cpp +++ b/src/query/frontend/semantic/symbol_generator.cpp @@ -333,6 +333,12 @@ bool SymbolGenerator::PreVisit(All &all) { return false; } +bool SymbolGenerator::PreVisit(Single &single) { + single.list_expression_->Accept(*this); + VisitWithIdentifiers(*single.where_, {single.identifier_}); + return false; +} + bool SymbolGenerator::PreVisit(Reduce &reduce) { reduce.initializer_->Accept(*this); reduce.list_->Accept(*this); diff --git a/src/query/frontend/semantic/symbol_generator.hpp b/src/query/frontend/semantic/symbol_generator.hpp index 0b68415f0..178fbafb0 100644 --- a/src/query/frontend/semantic/symbol_generator.hpp +++ b/src/query/frontend/semantic/symbol_generator.hpp @@ -57,6 +57,7 @@ class SymbolGenerator : public HierarchicalTreeVisitor { bool PreVisit(IfOperator &) override; bool PostVisit(IfOperator &) override; bool PreVisit(All &) override; + bool PreVisit(Single &) override; bool PreVisit(Reduce &) override; // Pattern and its subparts. diff --git a/src/query/interpret/eval.hpp b/src/query/interpret/eval.hpp index 44bb40d4d..22904e5b3 100644 --- a/src/query/interpret/eval.hpp +++ b/src/query/interpret/eval.hpp @@ -402,6 +402,40 @@ class ExpressionEvaluator : public TreeVisitor<TypedValue> { return true; } + TypedValue Visit(Single &single) override { + auto list_value = single.list_expression_->Accept(*this); + if (list_value.IsNull()) { + return TypedValue::Null; + } + if (list_value.type() != TypedValue::Type::List) { + throw QueryRuntimeException("'SINGLE' expected a list, but got {}", + list_value.type()); + } + const auto &list = list_value.Value<std::vector<TypedValue>>(); + const auto &symbol = symbol_table_.at(*single.identifier_); + bool predicate_satisfied = false; + for (const auto &element : list) { + frame_[symbol] = element; + auto result = single.where_->expression_->Accept(*this); + if (!result.IsNull() && result.type() != TypedValue::Type::Bool) { + throw QueryRuntimeException( + "Predicate of 'SINGLE' needs to evaluate to 'Boolean', but it " + "resulted in '{}'", + result.type()); + } + if (result.IsNull() || !result.Value<bool>()) { + continue; + } + // Return false if more than one element satisfies the predicate. + if (predicate_satisfied) { + return false; + } else { + predicate_satisfied = true; + } + } + return predicate_satisfied; + } + TypedValue Visit(ParameterLookup ¶m_lookup) override { return parameters_.AtTokenPosition(param_lookup.token_position_); } diff --git a/src/query/plan/preprocess.hpp b/src/query/plan/preprocess.hpp index 3f7e9b24f..c21c7c6b8 100644 --- a/src/query/plan/preprocess.hpp +++ b/src/query/plan/preprocess.hpp @@ -30,6 +30,13 @@ class UsedSymbolsCollector : public HierarchicalTreeVisitor { return true; } + bool PostVisit(Single &single) override { + // Remove the symbol which is bound by single, because we are only + // interested in free (unbound) symbols. + symbols_.erase(symbol_table_.at(*single.identifier_)); + return true; + } + bool PostVisit(Reduce &reduce) override { // Remove the symbols bound by reduce, because we are only interested // in free (unbound) symbols. @@ -38,7 +45,6 @@ class UsedSymbolsCollector : public HierarchicalTreeVisitor { return true; } - bool Visit(Identifier &ident) override { symbols_.insert(symbol_table_.at(ident)); return true; diff --git a/src/query/plan/rule_based_planner.cpp b/src/query/plan/rule_based_planner.cpp index 5888491d5..c2798b867 100644 --- a/src/query/plan/rule_based_planner.cpp +++ b/src/query/plan/rule_based_planner.cpp @@ -210,6 +210,21 @@ class ReturnBodyContext : public HierarchicalTreeVisitor { return true; } + bool PostVisit(Single &single) override { + // Remove the symbol which is bound by single, because we are only + // interested in free (unbound) symbols. + used_symbols_.erase(symbol_table_.at(*single.identifier_)); + DCHECK(has_aggregation_.size() >= 3U) + << "Expected 3 has_aggregation_ flags for SINGLE arguments"; + bool has_aggr = false; + for (int i = 0; i < 3; ++i) { + has_aggr = has_aggr || has_aggregation_.back(); + has_aggregation_.pop_back(); + } + has_aggregation_.emplace_back(has_aggr); + return true; + } + bool PostVisit(Reduce &reduce) override { // Remove the symbols bound by reduce, because we are only interested // in free (unbound) symbols. @@ -226,7 +241,6 @@ class ReturnBodyContext : public HierarchicalTreeVisitor { return true; } - bool Visit(Identifier &ident) override { const auto &symbol = symbol_table_.at(ident); if (!utils::Contains(output_symbols_, symbol)) { diff --git a/tests/qa/tck_engine/tests/memgraph_V1/features/functions.feature b/tests/qa/tck_engine/tests/memgraph_V1/features/functions.feature index c0c7b0631..1004c0f88 100644 --- a/tests/qa/tck_engine/tests/memgraph_V1/features/functions.feature +++ b/tests/qa/tck_engine/tests/memgraph_V1/features/functions.feature @@ -678,6 +678,31 @@ Feature: Functions """ Then an error should be raised + Scenario: Single test 01: + When executing query: + """ + RETURN single(x IN [1, 2, '3'] WHERE x < 4) AS a + """ + Then the result should be: + | a | + | false | + + Scenario: Single test 02: + When executing query: + """ + RETURN single(x IN [1, 2, 3] WHERE x = 1) AS a + """ + Then the result should be: + | a | + | true | + + Scenario: Single test 03: + When executing query: + """ + RETURN single(x IN [1, 2, '3'] WHERE x > 2) AS a + """ + Then an error should be raised + Scenario: Reduce test 01: When executing query: """ diff --git a/tests/unit/cypher_main_visitor.cpp b/tests/unit/cypher_main_visitor.cpp index 244e3ed6d..d47219b68 100644 --- a/tests/unit/cypher_main_visitor.cpp +++ b/tests/unit/cypher_main_visitor.cpp @@ -1563,6 +1563,25 @@ TYPED_TEST(CypherMainVisitorTest, ReturnAll) { EXPECT_TRUE(eq); } +TYPED_TEST(CypherMainVisitorTest, ReturnSingle) { + TypeParam ast_generator("RETURN single(x IN [1,2,3] WHERE x = 2)"); + auto *query = ast_generator.query_; + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *ret = dynamic_cast<Return *>(single_query->clauses_[0]); + ASSERT_TRUE(ret); + ASSERT_EQ(ret->body_.named_expressions.size(), 1U); + auto *single = + dynamic_cast<Single *>(ret->body_.named_expressions[0]->expression_); + ASSERT_TRUE(single); + EXPECT_EQ(single->identifier_->name_, "x"); + auto *list_literal = dynamic_cast<ListLiteral *>(single->list_expression_); + EXPECT_TRUE(list_literal); + auto *eq = dynamic_cast<EqualOperator *>(single->where_->expression_); + EXPECT_TRUE(eq); +} + TYPED_TEST(CypherMainVisitorTest, ReturnReduce) { TypeParam ast_generator("RETURN reduce(sum = 0, x IN [1,2,3] | sum + x)"); auto *query = ast_generator.query_; diff --git a/tests/unit/query_common.hpp b/tests/unit/query_common.hpp index d5292d465..1cafb9e76 100644 --- a/tests/unit/query_common.hpp +++ b/tests/unit/query_common.hpp @@ -600,6 +600,9 @@ auto GetMerge(AstTreeStorage &storage, Pattern *pattern, OnMatch on_match, #define ALL(variable, list, where) \ storage.Create<query::All>(storage.Create<query::Identifier>(variable), \ list, where) +#define SINGLE(variable, list, where) \ + storage.Create<query::Single>(storage.Create<query::Identifier>(variable), \ + list, where) #define REDUCE(accumulator, initializer, variable, list, expr) \ storage.Create<query::Reduce>( \ storage.Create<query::Identifier>(accumulator), initializer, \ diff --git a/tests/unit/query_expression_evaluator.cpp b/tests/unit/query_expression_evaluator.cpp index b88214dab..f1f6f0bdd 100644 --- a/tests/unit/query_expression_evaluator.cpp +++ b/tests/unit/query_expression_evaluator.cpp @@ -1221,6 +1221,44 @@ TEST(ExpressionEvaluator, FunctionAllWhereWrongType) { EXPECT_THROW(all->Accept(eval.eval), QueryRuntimeException); } +TEST(ExpressionEvaluator, FunctionSingle) { + AstTreeStorage storage; + auto *ident_x = IDENT("x"); + auto *single = + SINGLE("x", LIST(LITERAL(1), LITERAL(2)), WHERE(EQ(ident_x, LITERAL(1)))); + NoContextExpressionEvaluator eval; + const auto x_sym = eval.symbol_table.CreateSymbol("x", true); + eval.symbol_table[*single->identifier_] = x_sym; + eval.symbol_table[*ident_x] = x_sym; + auto value = single->Accept(eval.eval); + ASSERT_EQ(value.type(), TypedValue::Type::Bool); + EXPECT_TRUE(value.Value<bool>()); +} + +TEST(ExpressionEvaluator, FunctionSingle2) { + AstTreeStorage storage; + auto *ident_x = IDENT("x"); + auto *single = SINGLE("x", LIST(LITERAL(1), LITERAL(2)), + WHERE(GREATER(ident_x, LITERAL(0)))); + NoContextExpressionEvaluator eval; + const auto x_sym = eval.symbol_table.CreateSymbol("x", true); + eval.symbol_table[*single->identifier_] = x_sym; + eval.symbol_table[*ident_x] = x_sym; + auto value = single->Accept(eval.eval); + ASSERT_EQ(value.type(), TypedValue::Type::Bool); + EXPECT_FALSE(value.Value<bool>()); +} + +TEST(ExpressionEvaluator, FunctionSingleNullList) { + AstTreeStorage storage; + auto *single = SINGLE("x", LITERAL(TypedValue::Null), WHERE(LITERAL(true))); + NoContextExpressionEvaluator eval; + const auto x_sym = eval.symbol_table.CreateSymbol("x", true); + eval.symbol_table[*single->identifier_] = x_sym; + auto value = single->Accept(eval.eval); + EXPECT_TRUE(value.IsNull()); +} + TEST(ExpressionEvaluator, FunctionReduce) { AstTreeStorage storage; auto *ident_sum = IDENT("sum"); diff --git a/tests/unit/query_semantic.cpp b/tests/unit/query_semantic.cpp index 09839c216..790e720b1 100644 --- a/tests/unit/query_semantic.cpp +++ b/tests/unit/query_semantic.cpp @@ -795,6 +795,29 @@ TEST_F(TestSymbolGenerator, WithReturnAll) { EXPECT_NE(symbol_table.at(*all->identifier_), symbol_table.at(*ret_as_x)); } +TEST_F(TestSymbolGenerator, WithReturnSingle) { + // Test WITH 42 AS x RETURN single(x IN [x] WHERE x = 2) AS x, x AS y + auto *with_as_x = AS("x"); + auto *list_x = IDENT("x"); + auto *where_x = IDENT("x"); + auto *single = SINGLE("x", LIST(list_x), WHERE(EQ(where_x, LITERAL(2)))); + auto *ret_as_x = AS("x"); + auto *ret_x = IDENT("x"); + auto query = QUERY(SINGLE_QUERY(WITH(LITERAL(42), with_as_x), + RETURN(single, ret_as_x, ret_x, AS("y")))); + query->Accept(symbol_generator); + // Symbols for `WITH .. AS x`, `SINGLE(x ...)`, `SINGLE(...) AS x` and `AS y`. + EXPECT_EQ(symbol_table.max_position(), 4); + // Check `WITH .. AS x` is the same as `[x]` and `RETURN ... x AS y` + EXPECT_EQ(symbol_table.at(*with_as_x), symbol_table.at(*list_x)); + EXPECT_EQ(symbol_table.at(*with_as_x), symbol_table.at(*ret_x)); + EXPECT_NE(symbol_table.at(*with_as_x), symbol_table.at(*single->identifier_)); + EXPECT_NE(symbol_table.at(*with_as_x), symbol_table.at(*ret_as_x)); + // Check `SINGLE(x ...)` is only equal to `WHERE x = 2` + EXPECT_EQ(symbol_table.at(*single->identifier_), symbol_table.at(*where_x)); + EXPECT_NE(symbol_table.at(*single->identifier_), symbol_table.at(*ret_as_x)); +} + TEST_F(TestSymbolGenerator, WithReturnReduce) { // Test WITH 42 AS x RETURN reduce(y = 0, x IN [x] y + x) AS x, x AS y auto *with_as_x = AS("x");