Add prototype implementation of the any function

Summary: Add any function prototype - no tests

Reviewers: mferencevic

Reviewed By: mferencevic

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D2785
This commit is contained in:
jseljan 2020-06-10 16:11:35 +02:00
parent 3dd393f8eb
commit 098333f735
11 changed files with 166 additions and 3 deletions

View File

@ -1028,6 +1028,47 @@ cpp<#
(:serialize (:slk))
(:clone))
;; TODO: This is pretty much copy pasted from All. Consider merging Reduce,
;; All, Any and Single into something like a higher-order function call which
;; takes a list argument and a function which is applied on list elements.
(lcp:define-class any (expression)
((identifier "Identifier *" :initval "nullptr" :scope :public
:slk-save #'slk-save-ast-pointer
:slk-load (slk-load-ast-pointer "Identifier"))
(list-expression "Expression *" :initval "nullptr" :scope :public
:slk-save #'slk-save-ast-pointer
:slk-load (slk-load-ast-pointer "Expression"))
(where "Where *" :initval "nullptr" :scope :public
:slk-save #'slk-save-ast-pointer
:slk-load (slk-load-ast-pointer "Where")))
(:public
#>cpp
Any() = default;
DEFVISITABLE(ExpressionVisitor<TypedValue>);
DEFVISITABLE(ExpressionVisitor<void>);
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
identifier_->Accept(visitor) && list_expression_->Accept(visitor) &&
where_->Accept(visitor);
}
return visitor.PostVisit(*this);
}
cpp<#)
(:protected
#>cpp
Any(Identifier *identifier, Expression *list_expression, Where *where)
: identifier_(identifier),
list_expression_(list_expression),
where_(where) {}
cpp<#)
(:private
#>cpp
friend class AstStorage;
cpp<#)
(:serialize (:slk))
(:clone))
(lcp:define-class parameter-lookup (expression)
((token-position :int32_t :initval -1 :scope :public
:documentation "This field contains token position of *literal* used to create ParameterLookup object. If ParameterLookup object is not created from a literal leave this value at -1."))

View File

@ -19,6 +19,7 @@ class Coalesce;
class Extract;
class All;
class Single;
class Any;
class ParameterLookup;
class CallProcedure;
class Create;
@ -79,7 +80,7 @@ using TreeCompositeVisitor = ::utils::CompositeVisitor<
GreaterEqualOperator, InListOperator, SubscriptOperator,
ListSlicingOperator, IfOperator, UnaryPlusOperator, UnaryMinusOperator,
IsNullOperator, ListLiteral, MapLiteral, PropertyLookup, LabelsTest,
Aggregation, Function, Reduce, Coalesce, Extract, All, Single,
Aggregation, Function, Reduce, Coalesce, Extract, All, Single, Any,
CallProcedure, Create, Match, Return, With, Pattern, NodeAtom, EdgeAtom,
Delete, Where, SetProperty, SetProperties, SetLabels, RemoveProperty,
RemoveLabels, Merge, Unwind, RegexMatch>;
@ -107,8 +108,8 @@ class ExpressionVisitor
SubscriptOperator, ListSlicingOperator, IfOperator, UnaryPlusOperator,
UnaryMinusOperator, IsNullOperator, ListLiteral, MapLiteral,
PropertyLookup, LabelsTest, Aggregation, Function, Reduce, Coalesce,
Extract, All, Single, ParameterLookup, Identifier, PrimitiveLiteral,
RegexMatch> {};
Extract, All, Single, Any, ParameterLookup, Identifier,
PrimitiveLiteral, RegexMatch> {};
template <class TResult>
class QueryVisitor

View File

@ -1377,6 +1377,20 @@ antlrcpp::Any CypherMainVisitor::visitAtom(MemgraphCypher::AtomContext *ctx) {
Where *where = ctx->filterExpression()->where()->accept(this);
return static_cast<Expression *>(
storage_->Create<Single>(ident, list_expr, where));
} else if (ctx->ANY()) {
auto *ident = storage_->Create<Identifier>(ctx->filterExpression()
->idInColl()
->variable()
->accept(this)
.as<std::string>());
Expression *list_expr =
ctx->filterExpression()->idInColl()->expression()->accept(this);
if (!ctx->filterExpression()->where()) {
throw SyntaxException("ANY(...) requires a WHERE predicate.");
}
Where *where = ctx->filterExpression()->where()->accept(this);
return static_cast<Expression *>(
storage_->Create<Any>(ident, list_expr, where));
} else if (ctx->REDUCE()) {
auto *accumulator = storage_->Create<Identifier>(
ctx->reduceExpression()->accumulator->accept(this).as<std::string>());

View File

@ -51,6 +51,7 @@ class ExpressionPrettyPrinter : public ExpressionVisitor<void> {
void Visit(Extract &op) override;
void Visit(All &op) override;
void Visit(Single &op) override;
void Visit(Any &op) override;
void Visit(Identifier &op) override;
void Visit(PrimitiveLiteral &op) override;
void Visit(PropertyLookup &op) override;
@ -286,6 +287,11 @@ void ExpressionPrettyPrinter::Visit(Single &op) {
op.where_->expression_);
}
void ExpressionPrettyPrinter::Visit(Any &op) {
PrintOperator(out_, "Any", op.identifier_, op.list_expression_,
op.where_->expression_);
}
void ExpressionPrettyPrinter::Visit(Identifier &op) {
PrintOperator(out_, "Identifier", op.name_);
}

View File

@ -354,6 +354,12 @@ bool SymbolGenerator::PreVisit(Single &single) {
return false;
}
bool SymbolGenerator::PreVisit(Any &any) {
any.list_expression_->Accept(*this);
VisitWithIdentifiers(any.where_->expression_, {any.identifier_});
return false;
}
bool SymbolGenerator::PreVisit(Reduce &reduce) {
reduce.initializer_->Accept(*this);
reduce.list_->Accept(*this);

View File

@ -58,6 +58,7 @@ class SymbolGenerator : public HierarchicalTreeVisitor {
bool PostVisit(IfOperator &) override;
bool PreVisit(All &) override;
bool PreVisit(Single &) override;
bool PreVisit(Any &) override;
bool PreVisit(Reduce &) override;
bool PreVisit(Extract &) override;

View File

@ -523,6 +523,32 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
return TypedValue(predicate_satisfied, ctx_->memory);
}
TypedValue Visit(Any &any) override {
auto list_value = any.list_expression_->Accept(*this);
if (list_value.IsNull()) {
return TypedValue(ctx_->memory);
}
if (list_value.type() != TypedValue::Type::List) {
throw QueryRuntimeException("ANY expected a list, got {}.",
list_value.type());
}
const auto &list = list_value.ValueList();
const auto &symbol = symbol_table_->at(*any.identifier_);
for (const auto &element : list) {
frame_->at(symbol) = element;
auto result = any.where_->expression_->Accept(*this);
if (!result.IsNull() && result.type() != TypedValue::Type::Bool) {
throw QueryRuntimeException(
"Predicate of ANY must evaluate to boolean, got {}.",
result.type());
}
if (result.IsNull() || result.ValueBool()) {
return result;
}
}
return TypedValue(false, ctx_->memory);
}
TypedValue Visit(ParameterLookup &param_lookup) override {
return TypedValue(
ctx_->parameters.AtTokenPosition(param_lookup.token_position_),

View File

@ -37,6 +37,13 @@ class UsedSymbolsCollector : public HierarchicalTreeVisitor {
return true;
}
bool PostVisit(Any &any) override {
// Remove the symbol which is bound by all, because we are only interested
// in free (unbound) symbols.
symbols_.erase(symbol_table_.at(*any.identifier_));
return true;
}
bool PostVisit(Reduce &reduce) override {
// Remove the symbols bound by reduce, because we are only interested
// in free (unbound) symbols.

View File

@ -165,6 +165,21 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
return true;
}
bool PostVisit(Any &any) override {
// Remove the symbol which is bound by any, because we are only interested
// in free (unbound) symbols.
used_symbols_.erase(symbol_table_.at(*any.identifier_));
CHECK(has_aggregation_.size() >= 3U)
<< "Expected 3 has_aggregation_ flags for ANY arguments";
bool has_aggr = false;
for (int i = 0; i < 3; ++i) {
has_aggr = has_aggr || has_aggregation_.back();
has_aggregation_.pop_back();
}
has_aggregation_.emplace_back(has_aggr);
return true;
}
bool PostVisit(Reduce &reduce) override {
// Remove the symbols bound by reduce, because we are only interested
// in free (unbound) symbols.

View File

@ -639,6 +639,9 @@ auto GetMerge(AstStorage &storage, Pattern *pattern, OnMatch on_match,
#define SINGLE(variable, list, where) \
storage.Create<query::Single>(storage.Create<query::Identifier>(variable), \
list, where)
#define ANY(variable, list, where) \
storage.Create<query::Any>(storage.Create<query::Identifier>(variable), \
list, where)
#define REDUCE(accumulator, initializer, variable, list, expr) \
storage.Create<query::Reduce>( \
storage.Create<query::Identifier>(accumulator), initializer, \

View File

@ -769,6 +769,49 @@ TEST_F(ExpressionEvaluatorTest, FunctionSingleNullList) {
EXPECT_TRUE(value.IsNull());
}
TEST_F(ExpressionEvaluatorTest, FunctionAny) {
AstStorage storage;
auto *ident_x = IDENT("x");
auto *any =
ANY("x", LIST(LITERAL(1), LITERAL(2)), WHERE(EQ(ident_x, LITERAL(1))));
const auto x_sym = symbol_table.CreateSymbol("x", true);
any->identifier_->MapTo(x_sym);
ident_x->MapTo(x_sym);
auto value = Eval(any);
ASSERT_TRUE(value.IsBool());
EXPECT_TRUE(value.ValueBool());
}
TEST_F(ExpressionEvaluatorTest, FunctionAny2) {
AstStorage storage;
auto *ident_x = IDENT("x");
auto *any =
ANY("x", LIST(LITERAL(1), LITERAL(2)), WHERE(EQ(ident_x, LITERAL(0))));
const auto x_sym = symbol_table.CreateSymbol("x", true);
any->identifier_->MapTo(x_sym);
ident_x->MapTo(x_sym);
auto value = Eval(any);
ASSERT_TRUE(value.IsBool());
EXPECT_FALSE(value.ValueBool());
}
TEST_F(ExpressionEvaluatorTest, FunctionAnyNullList) {
AstStorage storage;
auto *any = ANY("x", LITERAL(storage::PropertyValue()), WHERE(LITERAL(true)));
const auto x_sym = symbol_table.CreateSymbol("x", true);
any->identifier_->MapTo(x_sym);
auto value = Eval(any);
EXPECT_TRUE(value.IsNull());
}
TEST_F(ExpressionEvaluatorTest, FunctionAnyWhereWrongType) {
AstStorage storage;
auto *any = ANY("x", LIST(LITERAL(1)), WHERE(LITERAL(2)));
const auto x_sym = symbol_table.CreateSymbol("x", true);
any->identifier_->MapTo(x_sym);
EXPECT_THROW(Eval(any), QueryRuntimeException);
}
TEST_F(ExpressionEvaluatorTest, FunctionReduce) {
AstStorage storage;
auto *ident_sum = IDENT("sum");