Add ALL function to openCypher
Summary: Add All expression to Ast Evaluate All expression Visit All and generate symbols Handle All when collecting context during planning Reviewers: florijan, mislav.bradac, buda Reviewed By: mislav.bradac Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D587
This commit is contained in:
parent
da7bfe05b7
commit
b33aae42ab
@ -4,6 +4,7 @@
|
||||
|
||||
### Major Features and Improvements
|
||||
|
||||
* Support for `all` function in openCypher.
|
||||
* User specified transaction execution timeout.
|
||||
* Support for query parameters (except for parameters in place of property maps).
|
||||
|
||||
|
@ -421,6 +421,7 @@ functions.
|
||||
`startsWith` | Check if the first argument starts with the second.
|
||||
`endsWith` | Check if the first argument ends with the second.
|
||||
`contains` | Check if the first argument has an element which is equal to the second argument.
|
||||
`all` | Check if all elements of a list satisfy a predicate.<br/>The syntax is: `all(variable IN list WHERE predicate)`.
|
||||
|
||||
#### Parameters
|
||||
|
||||
|
@ -1014,6 +1014,42 @@ class Where : public Tree {
|
||||
Where(int uid, Expression *expression) : Tree(uid), expression_(expression) {}
|
||||
};
|
||||
|
||||
class All : public Expression {
|
||||
friend class AstTreeStorage;
|
||||
|
||||
public:
|
||||
DEFVISITABLE(TreeVisitor<TypedValue>);
|
||||
bool Accept(HierarchicalTreeVisitor &visitor) override {
|
||||
if (visitor.PreVisit(*this)) {
|
||||
identifier_->Accept(visitor) && list_expression_->Accept(visitor) &&
|
||||
where_->Accept(visitor);
|
||||
}
|
||||
return visitor.PostVisit(*this);
|
||||
}
|
||||
|
||||
All *Clone(AstTreeStorage &storage) const override {
|
||||
return storage.Create<All>(identifier_->Clone(storage),
|
||||
list_expression_->Clone(storage),
|
||||
where_->Clone(storage));
|
||||
}
|
||||
|
||||
Identifier *identifier_ = nullptr;
|
||||
Expression *list_expression_ = nullptr;
|
||||
Where *where_ = nullptr;
|
||||
|
||||
protected:
|
||||
All(int uid, Identifier *identifier, Expression *list_expression,
|
||||
Where *where)
|
||||
: Expression(uid),
|
||||
identifier_(identifier),
|
||||
list_expression_(list_expression),
|
||||
where_(where) {
|
||||
debug_assert(identifier, "identifier must not be nullptr");
|
||||
debug_assert(list_expression, "list_expression must not be nullptr");
|
||||
debug_assert(where, "where must not be nullptr");
|
||||
}
|
||||
};
|
||||
|
||||
class Match : public Clause {
|
||||
friend class AstTreeStorage;
|
||||
|
||||
|
@ -13,6 +13,7 @@ class LabelsTest;
|
||||
class EdgeTypeTest;
|
||||
class Aggregation;
|
||||
class Function;
|
||||
class All;
|
||||
class Create;
|
||||
class Match;
|
||||
class Return;
|
||||
@ -63,9 +64,9 @@ using TreeCompositeVisitor = ::utils::CompositeVisitor<
|
||||
GreaterEqualOperator, InListOperator, ListIndexingOperator,
|
||||
ListSlicingOperator, UnaryPlusOperator, UnaryMinusOperator, IsNullOperator,
|
||||
ListLiteral, PropertyLookup, LabelsTest, EdgeTypeTest, Aggregation,
|
||||
Function, Create, Match, Return, With, Pattern, NodeAtom, EdgeAtom, Delete,
|
||||
Where, SetProperty, SetProperties, SetLabels, RemoveProperty, RemoveLabels,
|
||||
Merge, Unwind>;
|
||||
Function, All, Create, Match, Return, With, Pattern, NodeAtom, EdgeAtom,
|
||||
Delete, Where, SetProperty, SetProperties, SetLabels, RemoveProperty,
|
||||
RemoveLabels, Merge, Unwind>;
|
||||
|
||||
using TreeLeafVisitor =
|
||||
::utils::LeafVisitor<Identifier, PrimitiveLiteral, CreateIndex>;
|
||||
@ -88,8 +89,8 @@ using TreeVisitor = ::utils::Visitor<
|
||||
GreaterEqualOperator, InListOperator, ListIndexingOperator,
|
||||
ListSlicingOperator, UnaryPlusOperator, UnaryMinusOperator, IsNullOperator,
|
||||
ListLiteral, PropertyLookup, LabelsTest, EdgeTypeTest, Aggregation,
|
||||
Function, Create, Match, Return, With, Pattern, NodeAtom, EdgeAtom, Delete,
|
||||
Where, SetProperty, SetProperties, SetLabels, RemoveProperty, RemoveLabels,
|
||||
Merge, Unwind, Identifier, PrimitiveLiteral, CreateIndex>;
|
||||
Function, All, Create, Match, Return, With, Pattern, NodeAtom, EdgeAtom,
|
||||
Delete, Where, SetProperty, SetProperties, SetLabels, RemoveProperty,
|
||||
RemoveLabels, Merge, Unwind, Identifier, PrimitiveLiteral, CreateIndex>;
|
||||
|
||||
} // namespace query
|
||||
|
@ -756,6 +756,17 @@ antlrcpp::Any CypherMainVisitor::visitAtom(CypherParser::AtomContext *ctx) {
|
||||
// functionInvocation and atom producions in opencypher grammar.
|
||||
return static_cast<Expression *>(
|
||||
storage_.Create<Aggregation>(nullptr, Aggregation::Op::COUNT));
|
||||
} else if (ctx->ALL()) {
|
||||
auto *ident = storage_.Create<Identifier>(ctx->filterExpression()
|
||||
->idInColl()
|
||||
->variable()
|
||||
->accept(this)
|
||||
.as<std::string>());
|
||||
Expression *list_expr =
|
||||
ctx->filterExpression()->idInColl()->expression()->accept(this);
|
||||
Where *where = ctx->filterExpression()->where()->accept(this);
|
||||
return static_cast<Expression *>(
|
||||
storage_.Create<All>(ident, list_expr, where));
|
||||
}
|
||||
// TODO: Implement this. We don't support comprehensions, filtering... at
|
||||
// the moment.
|
||||
@ -1024,4 +1035,10 @@ antlrcpp::Any CypherMainVisitor::visitUnwind(CypherParser::UnwindContext *ctx) {
|
||||
return storage_.Create<Unwind>(named_expr);
|
||||
}
|
||||
|
||||
antlrcpp::Any CypherMainVisitor::visitFilterExpression(
|
||||
CypherParser::FilterExpressionContext *) {
|
||||
debug_fail("Should never be called. See documentation in hpp.");
|
||||
return 0;
|
||||
}
|
||||
|
||||
} // namespace query::frontend
|
||||
|
@ -530,6 +530,13 @@ class CypherMainVisitor : public antlropencypher::CypherBaseVisitor {
|
||||
*/
|
||||
antlrcpp::Any visitUnwind(CypherParser::UnwindContext *ctx) override;
|
||||
|
||||
/**
|
||||
* Never call this. Ast generation for these expressions should be done by
|
||||
* explicitly visiting the members of @c FilterExpressionContext.
|
||||
*/
|
||||
antlrcpp::Any visitFilterExpression(
|
||||
CypherParser::FilterExpressionContext *) override;
|
||||
|
||||
public:
|
||||
Query *query() { return query_; }
|
||||
AstTreeStorage &storage() { return storage_; }
|
||||
|
@ -4,8 +4,11 @@
|
||||
|
||||
#include "query/frontend/semantic/symbol_generator.hpp"
|
||||
|
||||
#include <experimental/optional>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "utils/algorithm.hpp"
|
||||
|
||||
namespace query {
|
||||
|
||||
auto SymbolGenerator::CreateSymbol(const std::string &name, bool user_declared,
|
||||
@ -265,6 +268,28 @@ bool SymbolGenerator::PostVisit(Aggregation &) {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool SymbolGenerator::PreVisit(All &all) {
|
||||
all.list_expression_->Accept(*this);
|
||||
// Bind the new symbol after visiting the list expression. Keep the old symbol
|
||||
// so it can be restored.
|
||||
std::experimental::optional<Symbol> prev_symbol;
|
||||
auto prev_symbol_it = scope_.symbols.find(all.identifier_->name_);
|
||||
if (prev_symbol_it != scope_.symbols.end()) {
|
||||
prev_symbol = prev_symbol_it->second;
|
||||
}
|
||||
symbol_table_[*all.identifier_] = CreateSymbol(all.identifier_->name_, true);
|
||||
// Visit Where with the new symbol bound.
|
||||
all.where_->Accept(*this);
|
||||
// Restore the old symbol or just remove the newly bound if there was no
|
||||
// symbol before.
|
||||
if (prev_symbol) {
|
||||
scope_.symbols[all.identifier_->name_] = *prev_symbol;
|
||||
} else {
|
||||
scope_.symbols.erase(all.identifier_->name_);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Pattern and its subparts.
|
||||
|
||||
bool SymbolGenerator::PreVisit(Pattern &pattern) {
|
||||
|
@ -44,6 +44,7 @@ class SymbolGenerator : public HierarchicalTreeVisitor {
|
||||
ReturnType Visit(PrimitiveLiteral &) override { return true; }
|
||||
bool PreVisit(Aggregation &) override;
|
||||
bool PostVisit(Aggregation &) override;
|
||||
bool PreVisit(All &) override;
|
||||
|
||||
// Pattern and its subparts.
|
||||
bool PreVisit(Pattern &) override;
|
||||
|
@ -326,6 +326,33 @@ class ExpressionEvaluator : public TreeVisitor<TypedValue> {
|
||||
return function.function_(arguments, db_accessor_);
|
||||
}
|
||||
|
||||
TypedValue Visit(All &all) override {
|
||||
auto list_value = all.list_expression_->Accept(*this);
|
||||
if (list_value.IsNull()) {
|
||||
return TypedValue::Null;
|
||||
}
|
||||
if (list_value.type() != TypedValue::Type::List) {
|
||||
throw QueryRuntimeException("'ALL' expected a list, but got {}",
|
||||
list_value.type());
|
||||
}
|
||||
const auto &list = list_value.Value<std::vector<TypedValue>>();
|
||||
const auto &symbol = symbol_table_.at(*all.identifier_);
|
||||
for (const auto &element : list) {
|
||||
frame_[symbol] = element;
|
||||
auto result = all.where_->expression_->Accept(*this);
|
||||
if (!result.IsNull() && result.type() != TypedValue::Type::Bool) {
|
||||
throw QueryRuntimeException(
|
||||
"Predicate of 'ALL' needs to evaluate to 'Boolean', but it "
|
||||
"resulted in '{}'",
|
||||
result.type());
|
||||
}
|
||||
if (result.IsNull() || !result.Value<bool>()) {
|
||||
return result;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
private:
|
||||
// If the given TypedValue contains accessors, switch them to New or Old,
|
||||
// depending on use_new_ flag.
|
||||
|
@ -133,16 +133,22 @@ class UsedSymbolsCollector : public HierarchicalTreeVisitor {
|
||||
|
||||
using HierarchicalTreeVisitor::PreVisit;
|
||||
using HierarchicalTreeVisitor::PostVisit;
|
||||
using typename HierarchicalTreeVisitor::ReturnType;
|
||||
using HierarchicalTreeVisitor::Visit;
|
||||
|
||||
ReturnType Visit(Identifier &ident) override {
|
||||
bool PostVisit(All &all) override {
|
||||
// Remove the symbol which is bound by all, because we are only interested
|
||||
// in free (unbound) symbols.
|
||||
symbols_.erase(symbol_table_.at(*all.identifier_));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Visit(Identifier &ident) override {
|
||||
symbols_.insert(symbol_table_.at(ident));
|
||||
return true;
|
||||
}
|
||||
|
||||
ReturnType Visit(PrimitiveLiteral &) override { return true; }
|
||||
ReturnType Visit(query::CreateIndex &) override { return true; }
|
||||
bool Visit(PrimitiveLiteral &) override { return true; }
|
||||
bool Visit(query::CreateIndex &) override { return true; }
|
||||
|
||||
std::unordered_set<Symbol> symbols_;
|
||||
const SymbolTable &symbol_table_;
|
||||
@ -280,6 +286,21 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool PostVisit(All &all) override {
|
||||
// Remove the symbol which is bound by all, because we are only interested
|
||||
// in free (unbound) symbols.
|
||||
used_symbols_.erase(symbol_table_.at(*all.identifier_));
|
||||
debug_assert(has_aggregation_.size() >= 3U,
|
||||
"Expected 3 has_aggregation_ flags for ALL arguments");
|
||||
bool has_aggr = false;
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
has_aggr = has_aggr || has_aggregation_.back();
|
||||
has_aggregation_.pop_back();
|
||||
}
|
||||
has_aggregation_.emplace_back(has_aggr);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Visit(Identifier &ident) override {
|
||||
const auto &symbol = symbol_table_.at(ident);
|
||||
if (std::find(output_symbols_.begin(), output_symbols_.end(), symbol) ==
|
||||
|
@ -596,16 +596,6 @@ Feature: Functions
|
||||
| (:y) | false | true |
|
||||
| (:y) | false | true |
|
||||
|
||||
# Scenario: Keys test:
|
||||
# Given an empty graph
|
||||
# When executing query:
|
||||
# """
|
||||
# CREATE (n{x: 123, a: null, b: 'x', d: 1}) RETURN KEYS(n) AS a
|
||||
# """
|
||||
# Then the result should be (ignoring element order for lists)
|
||||
# | a |
|
||||
# | ['x', 'null', 'b'] |
|
||||
#
|
||||
Scenario: Pi test:
|
||||
When executing query:
|
||||
"""
|
||||
@ -614,3 +604,28 @@ Feature: Functions
|
||||
Then the result should be:
|
||||
| n |
|
||||
| 3.141592653589793 |
|
||||
|
||||
Scenario: All test 01:
|
||||
When executing query:
|
||||
"""
|
||||
RETURN all(x IN [1, 2, '3'] WHERE x < 2) AS a
|
||||
"""
|
||||
Then the result should be:
|
||||
| a |
|
||||
| false |
|
||||
|
||||
Scenario: All test 02:
|
||||
When executing query:
|
||||
"""
|
||||
RETURN all(x IN [1, 2, 3] WHERE x < 4) AS a
|
||||
"""
|
||||
Then the result should be:
|
||||
| a |
|
||||
| true |
|
||||
|
||||
Scenario: All test 03:
|
||||
When executing query:
|
||||
"""
|
||||
RETURN all(x IN [1, 2, '3'] WHERE x < 3) AS a
|
||||
"""
|
||||
Then an error should be raised
|
||||
|
@ -1383,4 +1383,21 @@ TYPED_TEST(CypherMainVisitorTest, CreateIndex) {
|
||||
ASSERT_EQ(create_index->property_,
|
||||
ast_generator.db_accessor_->property("slavko"));
|
||||
}
|
||||
|
||||
TYPED_TEST(CypherMainVisitorTest, ReturnAll) {
|
||||
TypeParam ast_generator("RETURN all(x IN [1,2,3] WHERE x = 2)");
|
||||
auto *query = ast_generator.query_;
|
||||
ASSERT_EQ(query->clauses_.size(), 1U);
|
||||
auto *ret = dynamic_cast<Return *>(query->clauses_[0]);
|
||||
ASSERT_TRUE(ret);
|
||||
ASSERT_EQ(ret->body_.named_expressions.size(), 1U);
|
||||
auto *all = dynamic_cast<All *>(ret->body_.named_expressions[0]->expression_);
|
||||
ASSERT_TRUE(all);
|
||||
EXPECT_EQ(all->identifier_->name_, "x");
|
||||
auto *list_literal = dynamic_cast<ListLiteral *>(all->list_expression_);
|
||||
EXPECT_TRUE(list_literal);
|
||||
auto *eq = dynamic_cast<EqualOperator *>(all->where_->expression_);
|
||||
EXPECT_TRUE(eq);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -483,3 +483,7 @@ auto GetMerge(AstTreeStorage &storage, Pattern *pattern, OnMatch on_match,
|
||||
// List slicing
|
||||
#define SLICE(list, lower_bound, upper_bound) \
|
||||
storage.Create<query::ListSlicingOperator>(list, lower_bound, upper_bound)
|
||||
// all(variable IN list WHERE predicate)
|
||||
#define ALL(variable, list, where) \
|
||||
storage.Create<query::All>(storage.Create<query::Identifier>(variable), \
|
||||
list, where)
|
||||
|
@ -6,14 +6,15 @@
|
||||
#include "gmock/gmock.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "database/dbms.hpp"
|
||||
#include "database/graph_db_accessor.hpp"
|
||||
#include "database/graph_db_datatypes.hpp"
|
||||
#include "database/dbms.hpp"
|
||||
#include "query/frontend/ast/ast.hpp"
|
||||
#include "query/frontend/opencypher/parser.hpp"
|
||||
#include "query/interpret/awesome_memgraph_functions.hpp"
|
||||
#include "query/interpret/eval.hpp"
|
||||
#include "query/interpret/frame.hpp"
|
||||
|
||||
#include "query_common.hpp"
|
||||
|
||||
using namespace query;
|
||||
@ -1007,4 +1008,38 @@ TEST(ExpressionEvaluator, FunctionContains) {
|
||||
EXPECT_FALSE(EvaluateFunction(kContains, {"cde", "abcdef"}).Value<bool>());
|
||||
EXPECT_FALSE(EvaluateFunction(kContains, {"abcdef", "dEf"}).Value<bool>());
|
||||
}
|
||||
|
||||
TEST(ExpressionEvaluator, FunctionAll) {
|
||||
AstTreeStorage storage;
|
||||
auto *ident_x = IDENT("x");
|
||||
auto *all =
|
||||
ALL("x", LIST(LITERAL(1), LITERAL(2)), WHERE(EQ(ident_x, LITERAL(1))));
|
||||
NoContextExpressionEvaluator eval;
|
||||
const auto x_sym = eval.symbol_table.CreateSymbol("x", true);
|
||||
eval.symbol_table[*all->identifier_] = x_sym;
|
||||
eval.symbol_table[*ident_x] = x_sym;
|
||||
auto value = all->Accept(eval.eval);
|
||||
ASSERT_EQ(value.type(), TypedValue::Type::Bool);
|
||||
EXPECT_FALSE(value.Value<bool>());
|
||||
}
|
||||
|
||||
TEST(ExpressionEvaluator, FunctionAllNullList) {
|
||||
AstTreeStorage storage;
|
||||
auto *all = ALL("x", LITERAL(TypedValue::Null), WHERE(LITERAL(true)));
|
||||
NoContextExpressionEvaluator eval;
|
||||
const auto x_sym = eval.symbol_table.CreateSymbol("x", true);
|
||||
eval.symbol_table[*all->identifier_] = x_sym;
|
||||
auto value = all->Accept(eval.eval);
|
||||
EXPECT_TRUE(value.IsNull());
|
||||
}
|
||||
|
||||
TEST(ExpressionEvaluator, FunctionAllWhereWrongType) {
|
||||
AstTreeStorage storage;
|
||||
auto *all = ALL("x", LIST(LITERAL(1)), WHERE(LITERAL(2)));
|
||||
NoContextExpressionEvaluator eval;
|
||||
const auto x_sym = eval.symbol_table.CreateSymbol("x", true);
|
||||
eval.symbol_table[*all->identifier_] = x_sym;
|
||||
EXPECT_THROW(all->Accept(eval.eval), QueryRuntimeException);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -1205,4 +1205,14 @@ TEST(TestLogicalPlanner, UnableToUsePropertyIndex) {
|
||||
ExpectProduce());
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, ReturnSumGroupByAll) {
|
||||
// Test RETURN sum([1,2,3]), all(x in [1] where x = 1)
|
||||
AstTreeStorage storage;
|
||||
auto sum = SUM(LIST(LITERAL(1), LITERAL(2), LITERAL(3)));
|
||||
auto *all = ALL("x", LIST(LITERAL(1)), WHERE(EQ(IDENT("x"), LITERAL(1))));
|
||||
QUERY(RETURN(sum, AS("sum"), all, AS("all")));
|
||||
auto aggr = ExpectAggregate({sum}, {all});
|
||||
CheckPlan(storage, aggr, ExpectProduce());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -984,5 +984,30 @@ TEST(TestSymbolGenerator, MatchPropertySameIdentifier) {
|
||||
EXPECT_EQ(n, symbol_table.at(*n_prop->expression_));
|
||||
}
|
||||
|
||||
TEST(TestSymbolGenerator, WithReturnAll) {
|
||||
// Test WITH 42 AS x RETURN all(x IN [x] WHERE x = 2) AS x, x AS y
|
||||
AstTreeStorage storage;
|
||||
auto *with_as_x = AS("x");
|
||||
auto *list_x = IDENT("x");
|
||||
auto *where_x = IDENT("x");
|
||||
auto *all = ALL("x", LIST(list_x), WHERE(EQ(where_x, LITERAL(2))));
|
||||
auto *ret_as_x = AS("x");
|
||||
auto *ret_x = IDENT("x");
|
||||
auto query = QUERY(WITH(LITERAL(42), with_as_x),
|
||||
RETURN(all, ret_as_x, ret_x, AS("y")));
|
||||
SymbolTable symbol_table;
|
||||
SymbolGenerator symbol_generator(symbol_table);
|
||||
query->Accept(symbol_generator);
|
||||
// Symbols for `WITH .. AS x`, `ALL(x ...)`, `ALL(...) AS x` and `AS y`.
|
||||
EXPECT_EQ(symbol_table.max_position(), 4);
|
||||
// Check `WITH .. AS x` is the same as `[x]` and `RETURN ... x AS y`
|
||||
EXPECT_EQ(symbol_table.at(*with_as_x), symbol_table.at(*list_x));
|
||||
EXPECT_EQ(symbol_table.at(*with_as_x), symbol_table.at(*ret_x));
|
||||
EXPECT_NE(symbol_table.at(*with_as_x), symbol_table.at(*all->identifier_));
|
||||
EXPECT_NE(symbol_table.at(*with_as_x), symbol_table.at(*ret_as_x));
|
||||
// Check `ALL(x ...)` is only equal to `WHERE x = 2`
|
||||
EXPECT_EQ(symbol_table.at(*all->identifier_), symbol_table.at(*where_x));
|
||||
EXPECT_NE(symbol_table.at(*all->identifier_), symbol_table.at(*ret_as_x));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
Loading…
Reference in New Issue
Block a user