Add ALL function to openCypher

Summary:
Add All expression to Ast
Evaluate All expression
Visit All and generate symbols
Handle All when collecting context during planning

Reviewers: florijan, mislav.bradac, buda

Reviewed By: mislav.bradac

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D587
This commit is contained in:
Teon Banek 2017-07-25 13:01:08 +02:00
parent da7bfe05b7
commit b33aae42ab
16 changed files with 264 additions and 21 deletions

View File

@ -4,6 +4,7 @@
### Major Features and Improvements
* Support for `all` function in openCypher.
* User specified transaction execution timeout.
* Support for query parameters (except for parameters in place of property maps).

View File

@ -421,6 +421,7 @@ functions.
`startsWith` | Check if the first argument starts with the second.
`endsWith` | Check if the first argument ends with the second.
`contains` | Check if the first argument has an element which is equal to the second argument.
`all` | Check if all elements of a list satisfy a predicate.<br/>The syntax is: `all(variable IN list WHERE predicate)`.
#### Parameters

View File

@ -1014,6 +1014,42 @@ class Where : public Tree {
Where(int uid, Expression *expression) : Tree(uid), expression_(expression) {}
};
class All : public Expression {
friend class AstTreeStorage;
public:
DEFVISITABLE(TreeVisitor<TypedValue>);
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
identifier_->Accept(visitor) && list_expression_->Accept(visitor) &&
where_->Accept(visitor);
}
return visitor.PostVisit(*this);
}
All *Clone(AstTreeStorage &storage) const override {
return storage.Create<All>(identifier_->Clone(storage),
list_expression_->Clone(storage),
where_->Clone(storage));
}
Identifier *identifier_ = nullptr;
Expression *list_expression_ = nullptr;
Where *where_ = nullptr;
protected:
All(int uid, Identifier *identifier, Expression *list_expression,
Where *where)
: Expression(uid),
identifier_(identifier),
list_expression_(list_expression),
where_(where) {
debug_assert(identifier, "identifier must not be nullptr");
debug_assert(list_expression, "list_expression must not be nullptr");
debug_assert(where, "where must not be nullptr");
}
};
class Match : public Clause {
friend class AstTreeStorage;

View File

@ -13,6 +13,7 @@ class LabelsTest;
class EdgeTypeTest;
class Aggregation;
class Function;
class All;
class Create;
class Match;
class Return;
@ -63,9 +64,9 @@ using TreeCompositeVisitor = ::utils::CompositeVisitor<
GreaterEqualOperator, InListOperator, ListIndexingOperator,
ListSlicingOperator, UnaryPlusOperator, UnaryMinusOperator, IsNullOperator,
ListLiteral, PropertyLookup, LabelsTest, EdgeTypeTest, Aggregation,
Function, Create, Match, Return, With, Pattern, NodeAtom, EdgeAtom, Delete,
Where, SetProperty, SetProperties, SetLabels, RemoveProperty, RemoveLabels,
Merge, Unwind>;
Function, All, Create, Match, Return, With, Pattern, NodeAtom, EdgeAtom,
Delete, Where, SetProperty, SetProperties, SetLabels, RemoveProperty,
RemoveLabels, Merge, Unwind>;
using TreeLeafVisitor =
::utils::LeafVisitor<Identifier, PrimitiveLiteral, CreateIndex>;
@ -88,8 +89,8 @@ using TreeVisitor = ::utils::Visitor<
GreaterEqualOperator, InListOperator, ListIndexingOperator,
ListSlicingOperator, UnaryPlusOperator, UnaryMinusOperator, IsNullOperator,
ListLiteral, PropertyLookup, LabelsTest, EdgeTypeTest, Aggregation,
Function, Create, Match, Return, With, Pattern, NodeAtom, EdgeAtom, Delete,
Where, SetProperty, SetProperties, SetLabels, RemoveProperty, RemoveLabels,
Merge, Unwind, Identifier, PrimitiveLiteral, CreateIndex>;
Function, All, Create, Match, Return, With, Pattern, NodeAtom, EdgeAtom,
Delete, Where, SetProperty, SetProperties, SetLabels, RemoveProperty,
RemoveLabels, Merge, Unwind, Identifier, PrimitiveLiteral, CreateIndex>;
} // namespace query

View File

@ -756,6 +756,17 @@ antlrcpp::Any CypherMainVisitor::visitAtom(CypherParser::AtomContext *ctx) {
// functionInvocation and atom producions in opencypher grammar.
return static_cast<Expression *>(
storage_.Create<Aggregation>(nullptr, Aggregation::Op::COUNT));
} else if (ctx->ALL()) {
auto *ident = storage_.Create<Identifier>(ctx->filterExpression()
->idInColl()
->variable()
->accept(this)
.as<std::string>());
Expression *list_expr =
ctx->filterExpression()->idInColl()->expression()->accept(this);
Where *where = ctx->filterExpression()->where()->accept(this);
return static_cast<Expression *>(
storage_.Create<All>(ident, list_expr, where));
}
// TODO: Implement this. We don't support comprehensions, filtering... at
// the moment.
@ -1024,4 +1035,10 @@ antlrcpp::Any CypherMainVisitor::visitUnwind(CypherParser::UnwindContext *ctx) {
return storage_.Create<Unwind>(named_expr);
}
antlrcpp::Any CypherMainVisitor::visitFilterExpression(
CypherParser::FilterExpressionContext *) {
debug_fail("Should never be called. See documentation in hpp.");
return 0;
}
} // namespace query::frontend

View File

@ -530,6 +530,13 @@ class CypherMainVisitor : public antlropencypher::CypherBaseVisitor {
*/
antlrcpp::Any visitUnwind(CypherParser::UnwindContext *ctx) override;
/**
* Never call this. Ast generation for these expressions should be done by
* explicitly visiting the members of @c FilterExpressionContext.
*/
antlrcpp::Any visitFilterExpression(
CypherParser::FilterExpressionContext *) override;
public:
Query *query() { return query_; }
AstTreeStorage &storage() { return storage_; }

View File

@ -4,8 +4,11 @@
#include "query/frontend/semantic/symbol_generator.hpp"
#include <experimental/optional>
#include <unordered_set>
#include "utils/algorithm.hpp"
namespace query {
auto SymbolGenerator::CreateSymbol(const std::string &name, bool user_declared,
@ -265,6 +268,28 @@ bool SymbolGenerator::PostVisit(Aggregation &) {
return true;
}
bool SymbolGenerator::PreVisit(All &all) {
all.list_expression_->Accept(*this);
// Bind the new symbol after visiting the list expression. Keep the old symbol
// so it can be restored.
std::experimental::optional<Symbol> prev_symbol;
auto prev_symbol_it = scope_.symbols.find(all.identifier_->name_);
if (prev_symbol_it != scope_.symbols.end()) {
prev_symbol = prev_symbol_it->second;
}
symbol_table_[*all.identifier_] = CreateSymbol(all.identifier_->name_, true);
// Visit Where with the new symbol bound.
all.where_->Accept(*this);
// Restore the old symbol or just remove the newly bound if there was no
// symbol before.
if (prev_symbol) {
scope_.symbols[all.identifier_->name_] = *prev_symbol;
} else {
scope_.symbols.erase(all.identifier_->name_);
}
return false;
}
// Pattern and its subparts.
bool SymbolGenerator::PreVisit(Pattern &pattern) {

View File

@ -44,6 +44,7 @@ class SymbolGenerator : public HierarchicalTreeVisitor {
ReturnType Visit(PrimitiveLiteral &) override { return true; }
bool PreVisit(Aggregation &) override;
bool PostVisit(Aggregation &) override;
bool PreVisit(All &) override;
// Pattern and its subparts.
bool PreVisit(Pattern &) override;

View File

@ -326,6 +326,33 @@ class ExpressionEvaluator : public TreeVisitor<TypedValue> {
return function.function_(arguments, db_accessor_);
}
TypedValue Visit(All &all) override {
auto list_value = all.list_expression_->Accept(*this);
if (list_value.IsNull()) {
return TypedValue::Null;
}
if (list_value.type() != TypedValue::Type::List) {
throw QueryRuntimeException("'ALL' expected a list, but got {}",
list_value.type());
}
const auto &list = list_value.Value<std::vector<TypedValue>>();
const auto &symbol = symbol_table_.at(*all.identifier_);
for (const auto &element : list) {
frame_[symbol] = element;
auto result = all.where_->expression_->Accept(*this);
if (!result.IsNull() && result.type() != TypedValue::Type::Bool) {
throw QueryRuntimeException(
"Predicate of 'ALL' needs to evaluate to 'Boolean', but it "
"resulted in '{}'",
result.type());
}
if (result.IsNull() || !result.Value<bool>()) {
return result;
}
}
return true;
}
private:
// If the given TypedValue contains accessors, switch them to New or Old,
// depending on use_new_ flag.

View File

@ -133,16 +133,22 @@ class UsedSymbolsCollector : public HierarchicalTreeVisitor {
using HierarchicalTreeVisitor::PreVisit;
using HierarchicalTreeVisitor::PostVisit;
using typename HierarchicalTreeVisitor::ReturnType;
using HierarchicalTreeVisitor::Visit;
ReturnType Visit(Identifier &ident) override {
bool PostVisit(All &all) override {
// Remove the symbol which is bound by all, because we are only interested
// in free (unbound) symbols.
symbols_.erase(symbol_table_.at(*all.identifier_));
return true;
}
bool Visit(Identifier &ident) override {
symbols_.insert(symbol_table_.at(ident));
return true;
}
ReturnType Visit(PrimitiveLiteral &) override { return true; }
ReturnType Visit(query::CreateIndex &) override { return true; }
bool Visit(PrimitiveLiteral &) override { return true; }
bool Visit(query::CreateIndex &) override { return true; }
std::unordered_set<Symbol> symbols_;
const SymbolTable &symbol_table_;
@ -280,6 +286,21 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
return true;
}
bool PostVisit(All &all) override {
// Remove the symbol which is bound by all, because we are only interested
// in free (unbound) symbols.
used_symbols_.erase(symbol_table_.at(*all.identifier_));
debug_assert(has_aggregation_.size() >= 3U,
"Expected 3 has_aggregation_ flags for ALL arguments");
bool has_aggr = false;
for (int i = 0; i < 3; ++i) {
has_aggr = has_aggr || has_aggregation_.back();
has_aggregation_.pop_back();
}
has_aggregation_.emplace_back(has_aggr);
return true;
}
bool Visit(Identifier &ident) override {
const auto &symbol = symbol_table_.at(ident);
if (std::find(output_symbols_.begin(), output_symbols_.end(), symbol) ==

View File

@ -596,16 +596,6 @@ Feature: Functions
| (:y) | false | true |
| (:y) | false | true |
# Scenario: Keys test:
# Given an empty graph
# When executing query:
# """
# CREATE (n{x: 123, a: null, b: 'x', d: 1}) RETURN KEYS(n) AS a
# """
# Then the result should be (ignoring element order for lists)
# | a |
# | ['x', 'null', 'b'] |
#
Scenario: Pi test:
When executing query:
"""
@ -614,3 +604,28 @@ Feature: Functions
Then the result should be:
| n |
| 3.141592653589793 |
Scenario: All test 01:
When executing query:
"""
RETURN all(x IN [1, 2, '3'] WHERE x < 2) AS a
"""
Then the result should be:
| a |
| false |
Scenario: All test 02:
When executing query:
"""
RETURN all(x IN [1, 2, 3] WHERE x < 4) AS a
"""
Then the result should be:
| a |
| true |
Scenario: All test 03:
When executing query:
"""
RETURN all(x IN [1, 2, '3'] WHERE x < 3) AS a
"""
Then an error should be raised

View File

@ -1383,4 +1383,21 @@ TYPED_TEST(CypherMainVisitorTest, CreateIndex) {
ASSERT_EQ(create_index->property_,
ast_generator.db_accessor_->property("slavko"));
}
TYPED_TEST(CypherMainVisitorTest, ReturnAll) {
TypeParam ast_generator("RETURN all(x IN [1,2,3] WHERE x = 2)");
auto *query = ast_generator.query_;
ASSERT_EQ(query->clauses_.size(), 1U);
auto *ret = dynamic_cast<Return *>(query->clauses_[0]);
ASSERT_TRUE(ret);
ASSERT_EQ(ret->body_.named_expressions.size(), 1U);
auto *all = dynamic_cast<All *>(ret->body_.named_expressions[0]->expression_);
ASSERT_TRUE(all);
EXPECT_EQ(all->identifier_->name_, "x");
auto *list_literal = dynamic_cast<ListLiteral *>(all->list_expression_);
EXPECT_TRUE(list_literal);
auto *eq = dynamic_cast<EqualOperator *>(all->where_->expression_);
EXPECT_TRUE(eq);
}
}

View File

@ -483,3 +483,7 @@ auto GetMerge(AstTreeStorage &storage, Pattern *pattern, OnMatch on_match,
// List slicing
#define SLICE(list, lower_bound, upper_bound) \
storage.Create<query::ListSlicingOperator>(list, lower_bound, upper_bound)
// all(variable IN list WHERE predicate)
#define ALL(variable, list, where) \
storage.Create<query::All>(storage.Create<query::Identifier>(variable), \
list, where)

View File

@ -6,14 +6,15 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "database/dbms.hpp"
#include "database/graph_db_accessor.hpp"
#include "database/graph_db_datatypes.hpp"
#include "database/dbms.hpp"
#include "query/frontend/ast/ast.hpp"
#include "query/frontend/opencypher/parser.hpp"
#include "query/interpret/awesome_memgraph_functions.hpp"
#include "query/interpret/eval.hpp"
#include "query/interpret/frame.hpp"
#include "query_common.hpp"
using namespace query;
@ -1007,4 +1008,38 @@ TEST(ExpressionEvaluator, FunctionContains) {
EXPECT_FALSE(EvaluateFunction(kContains, {"cde", "abcdef"}).Value<bool>());
EXPECT_FALSE(EvaluateFunction(kContains, {"abcdef", "dEf"}).Value<bool>());
}
TEST(ExpressionEvaluator, FunctionAll) {
AstTreeStorage storage;
auto *ident_x = IDENT("x");
auto *all =
ALL("x", LIST(LITERAL(1), LITERAL(2)), WHERE(EQ(ident_x, LITERAL(1))));
NoContextExpressionEvaluator eval;
const auto x_sym = eval.symbol_table.CreateSymbol("x", true);
eval.symbol_table[*all->identifier_] = x_sym;
eval.symbol_table[*ident_x] = x_sym;
auto value = all->Accept(eval.eval);
ASSERT_EQ(value.type(), TypedValue::Type::Bool);
EXPECT_FALSE(value.Value<bool>());
}
TEST(ExpressionEvaluator, FunctionAllNullList) {
AstTreeStorage storage;
auto *all = ALL("x", LITERAL(TypedValue::Null), WHERE(LITERAL(true)));
NoContextExpressionEvaluator eval;
const auto x_sym = eval.symbol_table.CreateSymbol("x", true);
eval.symbol_table[*all->identifier_] = x_sym;
auto value = all->Accept(eval.eval);
EXPECT_TRUE(value.IsNull());
}
TEST(ExpressionEvaluator, FunctionAllWhereWrongType) {
AstTreeStorage storage;
auto *all = ALL("x", LIST(LITERAL(1)), WHERE(LITERAL(2)));
NoContextExpressionEvaluator eval;
const auto x_sym = eval.symbol_table.CreateSymbol("x", true);
eval.symbol_table[*all->identifier_] = x_sym;
EXPECT_THROW(all->Accept(eval.eval), QueryRuntimeException);
}
}

View File

@ -1205,4 +1205,14 @@ TEST(TestLogicalPlanner, UnableToUsePropertyIndex) {
ExpectProduce());
}
TEST(TestLogicalPlanner, ReturnSumGroupByAll) {
// Test RETURN sum([1,2,3]), all(x in [1] where x = 1)
AstTreeStorage storage;
auto sum = SUM(LIST(LITERAL(1), LITERAL(2), LITERAL(3)));
auto *all = ALL("x", LIST(LITERAL(1)), WHERE(EQ(IDENT("x"), LITERAL(1))));
QUERY(RETURN(sum, AS("sum"), all, AS("all")));
auto aggr = ExpectAggregate({sum}, {all});
CheckPlan(storage, aggr, ExpectProduce());
}
} // namespace

View File

@ -984,5 +984,30 @@ TEST(TestSymbolGenerator, MatchPropertySameIdentifier) {
EXPECT_EQ(n, symbol_table.at(*n_prop->expression_));
}
TEST(TestSymbolGenerator, WithReturnAll) {
// Test WITH 42 AS x RETURN all(x IN [x] WHERE x = 2) AS x, x AS y
AstTreeStorage storage;
auto *with_as_x = AS("x");
auto *list_x = IDENT("x");
auto *where_x = IDENT("x");
auto *all = ALL("x", LIST(list_x), WHERE(EQ(where_x, LITERAL(2))));
auto *ret_as_x = AS("x");
auto *ret_x = IDENT("x");
auto query = QUERY(WITH(LITERAL(42), with_as_x),
RETURN(all, ret_as_x, ret_x, AS("y")));
SymbolTable symbol_table;
SymbolGenerator symbol_generator(symbol_table);
query->Accept(symbol_generator);
// Symbols for `WITH .. AS x`, `ALL(x ...)`, `ALL(...) AS x` and `AS y`.
EXPECT_EQ(symbol_table.max_position(), 4);
// Check `WITH .. AS x` is the same as `[x]` and `RETURN ... x AS y`
EXPECT_EQ(symbol_table.at(*with_as_x), symbol_table.at(*list_x));
EXPECT_EQ(symbol_table.at(*with_as_x), symbol_table.at(*ret_x));
EXPECT_NE(symbol_table.at(*with_as_x), symbol_table.at(*all->identifier_));
EXPECT_NE(symbol_table.at(*with_as_x), symbol_table.at(*ret_as_x));
// Check `ALL(x ...)` is only equal to `WHERE x = 2`
EXPECT_EQ(symbol_table.at(*all->identifier_), symbol_table.at(*where_x));
EXPECT_NE(symbol_table.at(*all->identifier_), symbol_table.at(*ret_as_x));
}
} // namespace