Add prototype implementation of the any function

Summary: Add any function prototype - no tests

Reviewers: mferencevic

Reviewed By: mferencevic

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D2785
This commit is contained in:
jseljan 2020-06-10 16:11:35 +02:00
parent 3dd393f8eb
commit 098333f735
11 changed files with 166 additions and 3 deletions

View File

@ -1028,6 +1028,47 @@ cpp<#
(:serialize (:slk)) (:serialize (:slk))
(:clone)) (:clone))
;; TODO: This is pretty much copy pasted from All. Consider merging Reduce,
;; All, Any and Single into something like a higher-order function call which
;; takes a list argument and a function which is applied on list elements.
(lcp:define-class any (expression)
((identifier "Identifier *" :initval "nullptr" :scope :public
:slk-save #'slk-save-ast-pointer
:slk-load (slk-load-ast-pointer "Identifier"))
(list-expression "Expression *" :initval "nullptr" :scope :public
:slk-save #'slk-save-ast-pointer
:slk-load (slk-load-ast-pointer "Expression"))
(where "Where *" :initval "nullptr" :scope :public
:slk-save #'slk-save-ast-pointer
:slk-load (slk-load-ast-pointer "Where")))
(:public
#>cpp
Any() = default;
DEFVISITABLE(ExpressionVisitor<TypedValue>);
DEFVISITABLE(ExpressionVisitor<void>);
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
identifier_->Accept(visitor) && list_expression_->Accept(visitor) &&
where_->Accept(visitor);
}
return visitor.PostVisit(*this);
}
cpp<#)
(:protected
#>cpp
Any(Identifier *identifier, Expression *list_expression, Where *where)
: identifier_(identifier),
list_expression_(list_expression),
where_(where) {}
cpp<#)
(:private
#>cpp
friend class AstStorage;
cpp<#)
(:serialize (:slk))
(:clone))
(lcp:define-class parameter-lookup (expression) (lcp:define-class parameter-lookup (expression)
((token-position :int32_t :initval -1 :scope :public ((token-position :int32_t :initval -1 :scope :public
:documentation "This field contains token position of *literal* used to create ParameterLookup object. If ParameterLookup object is not created from a literal leave this value at -1.")) :documentation "This field contains token position of *literal* used to create ParameterLookup object. If ParameterLookup object is not created from a literal leave this value at -1."))

View File

@ -19,6 +19,7 @@ class Coalesce;
class Extract; class Extract;
class All; class All;
class Single; class Single;
class Any;
class ParameterLookup; class ParameterLookup;
class CallProcedure; class CallProcedure;
class Create; class Create;
@ -79,7 +80,7 @@ using TreeCompositeVisitor = ::utils::CompositeVisitor<
GreaterEqualOperator, InListOperator, SubscriptOperator, GreaterEqualOperator, InListOperator, SubscriptOperator,
ListSlicingOperator, IfOperator, UnaryPlusOperator, UnaryMinusOperator, ListSlicingOperator, IfOperator, UnaryPlusOperator, UnaryMinusOperator,
IsNullOperator, ListLiteral, MapLiteral, PropertyLookup, LabelsTest, IsNullOperator, ListLiteral, MapLiteral, PropertyLookup, LabelsTest,
Aggregation, Function, Reduce, Coalesce, Extract, All, Single, Aggregation, Function, Reduce, Coalesce, Extract, All, Single, Any,
CallProcedure, Create, Match, Return, With, Pattern, NodeAtom, EdgeAtom, CallProcedure, Create, Match, Return, With, Pattern, NodeAtom, EdgeAtom,
Delete, Where, SetProperty, SetProperties, SetLabels, RemoveProperty, Delete, Where, SetProperty, SetProperties, SetLabels, RemoveProperty,
RemoveLabels, Merge, Unwind, RegexMatch>; RemoveLabels, Merge, Unwind, RegexMatch>;
@ -107,8 +108,8 @@ class ExpressionVisitor
SubscriptOperator, ListSlicingOperator, IfOperator, UnaryPlusOperator, SubscriptOperator, ListSlicingOperator, IfOperator, UnaryPlusOperator,
UnaryMinusOperator, IsNullOperator, ListLiteral, MapLiteral, UnaryMinusOperator, IsNullOperator, ListLiteral, MapLiteral,
PropertyLookup, LabelsTest, Aggregation, Function, Reduce, Coalesce, PropertyLookup, LabelsTest, Aggregation, Function, Reduce, Coalesce,
Extract, All, Single, ParameterLookup, Identifier, PrimitiveLiteral, Extract, All, Single, Any, ParameterLookup, Identifier,
RegexMatch> {}; PrimitiveLiteral, RegexMatch> {};
template <class TResult> template <class TResult>
class QueryVisitor class QueryVisitor

View File

@ -1377,6 +1377,20 @@ antlrcpp::Any CypherMainVisitor::visitAtom(MemgraphCypher::AtomContext *ctx) {
Where *where = ctx->filterExpression()->where()->accept(this); Where *where = ctx->filterExpression()->where()->accept(this);
return static_cast<Expression *>( return static_cast<Expression *>(
storage_->Create<Single>(ident, list_expr, where)); storage_->Create<Single>(ident, list_expr, where));
} else if (ctx->ANY()) {
auto *ident = storage_->Create<Identifier>(ctx->filterExpression()
->idInColl()
->variable()
->accept(this)
.as<std::string>());
Expression *list_expr =
ctx->filterExpression()->idInColl()->expression()->accept(this);
if (!ctx->filterExpression()->where()) {
throw SyntaxException("ANY(...) requires a WHERE predicate.");
}
Where *where = ctx->filterExpression()->where()->accept(this);
return static_cast<Expression *>(
storage_->Create<Any>(ident, list_expr, where));
} else if (ctx->REDUCE()) { } else if (ctx->REDUCE()) {
auto *accumulator = storage_->Create<Identifier>( auto *accumulator = storage_->Create<Identifier>(
ctx->reduceExpression()->accumulator->accept(this).as<std::string>()); ctx->reduceExpression()->accumulator->accept(this).as<std::string>());

View File

@ -51,6 +51,7 @@ class ExpressionPrettyPrinter : public ExpressionVisitor<void> {
void Visit(Extract &op) override; void Visit(Extract &op) override;
void Visit(All &op) override; void Visit(All &op) override;
void Visit(Single &op) override; void Visit(Single &op) override;
void Visit(Any &op) override;
void Visit(Identifier &op) override; void Visit(Identifier &op) override;
void Visit(PrimitiveLiteral &op) override; void Visit(PrimitiveLiteral &op) override;
void Visit(PropertyLookup &op) override; void Visit(PropertyLookup &op) override;
@ -286,6 +287,11 @@ void ExpressionPrettyPrinter::Visit(Single &op) {
op.where_->expression_); op.where_->expression_);
} }
void ExpressionPrettyPrinter::Visit(Any &op) {
PrintOperator(out_, "Any", op.identifier_, op.list_expression_,
op.where_->expression_);
}
void ExpressionPrettyPrinter::Visit(Identifier &op) { void ExpressionPrettyPrinter::Visit(Identifier &op) {
PrintOperator(out_, "Identifier", op.name_); PrintOperator(out_, "Identifier", op.name_);
} }

View File

@ -354,6 +354,12 @@ bool SymbolGenerator::PreVisit(Single &single) {
return false; return false;
} }
bool SymbolGenerator::PreVisit(Any &any) {
any.list_expression_->Accept(*this);
VisitWithIdentifiers(any.where_->expression_, {any.identifier_});
return false;
}
bool SymbolGenerator::PreVisit(Reduce &reduce) { bool SymbolGenerator::PreVisit(Reduce &reduce) {
reduce.initializer_->Accept(*this); reduce.initializer_->Accept(*this);
reduce.list_->Accept(*this); reduce.list_->Accept(*this);

View File

@ -58,6 +58,7 @@ class SymbolGenerator : public HierarchicalTreeVisitor {
bool PostVisit(IfOperator &) override; bool PostVisit(IfOperator &) override;
bool PreVisit(All &) override; bool PreVisit(All &) override;
bool PreVisit(Single &) override; bool PreVisit(Single &) override;
bool PreVisit(Any &) override;
bool PreVisit(Reduce &) override; bool PreVisit(Reduce &) override;
bool PreVisit(Extract &) override; bool PreVisit(Extract &) override;

View File

@ -523,6 +523,32 @@ class ExpressionEvaluator : public ExpressionVisitor<TypedValue> {
return TypedValue(predicate_satisfied, ctx_->memory); return TypedValue(predicate_satisfied, ctx_->memory);
} }
TypedValue Visit(Any &any) override {
auto list_value = any.list_expression_->Accept(*this);
if (list_value.IsNull()) {
return TypedValue(ctx_->memory);
}
if (list_value.type() != TypedValue::Type::List) {
throw QueryRuntimeException("ANY expected a list, got {}.",
list_value.type());
}
const auto &list = list_value.ValueList();
const auto &symbol = symbol_table_->at(*any.identifier_);
for (const auto &element : list) {
frame_->at(symbol) = element;
auto result = any.where_->expression_->Accept(*this);
if (!result.IsNull() && result.type() != TypedValue::Type::Bool) {
throw QueryRuntimeException(
"Predicate of ANY must evaluate to boolean, got {}.",
result.type());
}
if (result.IsNull() || result.ValueBool()) {
return result;
}
}
return TypedValue(false, ctx_->memory);
}
TypedValue Visit(ParameterLookup &param_lookup) override { TypedValue Visit(ParameterLookup &param_lookup) override {
return TypedValue( return TypedValue(
ctx_->parameters.AtTokenPosition(param_lookup.token_position_), ctx_->parameters.AtTokenPosition(param_lookup.token_position_),

View File

@ -37,6 +37,13 @@ class UsedSymbolsCollector : public HierarchicalTreeVisitor {
return true; return true;
} }
bool PostVisit(Any &any) override {
// Remove the symbol which is bound by all, because we are only interested
// in free (unbound) symbols.
symbols_.erase(symbol_table_.at(*any.identifier_));
return true;
}
bool PostVisit(Reduce &reduce) override { bool PostVisit(Reduce &reduce) override {
// Remove the symbols bound by reduce, because we are only interested // Remove the symbols bound by reduce, because we are only interested
// in free (unbound) symbols. // in free (unbound) symbols.

View File

@ -165,6 +165,21 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
return true; return true;
} }
bool PostVisit(Any &any) override {
// Remove the symbol which is bound by any, because we are only interested
// in free (unbound) symbols.
used_symbols_.erase(symbol_table_.at(*any.identifier_));
CHECK(has_aggregation_.size() >= 3U)
<< "Expected 3 has_aggregation_ flags for ANY arguments";
bool has_aggr = false;
for (int i = 0; i < 3; ++i) {
has_aggr = has_aggr || has_aggregation_.back();
has_aggregation_.pop_back();
}
has_aggregation_.emplace_back(has_aggr);
return true;
}
bool PostVisit(Reduce &reduce) override { bool PostVisit(Reduce &reduce) override {
// Remove the symbols bound by reduce, because we are only interested // Remove the symbols bound by reduce, because we are only interested
// in free (unbound) symbols. // in free (unbound) symbols.

View File

@ -639,6 +639,9 @@ auto GetMerge(AstStorage &storage, Pattern *pattern, OnMatch on_match,
#define SINGLE(variable, list, where) \ #define SINGLE(variable, list, where) \
storage.Create<query::Single>(storage.Create<query::Identifier>(variable), \ storage.Create<query::Single>(storage.Create<query::Identifier>(variable), \
list, where) list, where)
#define ANY(variable, list, where) \
storage.Create<query::Any>(storage.Create<query::Identifier>(variable), \
list, where)
#define REDUCE(accumulator, initializer, variable, list, expr) \ #define REDUCE(accumulator, initializer, variable, list, expr) \
storage.Create<query::Reduce>( \ storage.Create<query::Reduce>( \
storage.Create<query::Identifier>(accumulator), initializer, \ storage.Create<query::Identifier>(accumulator), initializer, \

View File

@ -769,6 +769,49 @@ TEST_F(ExpressionEvaluatorTest, FunctionSingleNullList) {
EXPECT_TRUE(value.IsNull()); EXPECT_TRUE(value.IsNull());
} }
TEST_F(ExpressionEvaluatorTest, FunctionAny) {
AstStorage storage;
auto *ident_x = IDENT("x");
auto *any =
ANY("x", LIST(LITERAL(1), LITERAL(2)), WHERE(EQ(ident_x, LITERAL(1))));
const auto x_sym = symbol_table.CreateSymbol("x", true);
any->identifier_->MapTo(x_sym);
ident_x->MapTo(x_sym);
auto value = Eval(any);
ASSERT_TRUE(value.IsBool());
EXPECT_TRUE(value.ValueBool());
}
TEST_F(ExpressionEvaluatorTest, FunctionAny2) {
AstStorage storage;
auto *ident_x = IDENT("x");
auto *any =
ANY("x", LIST(LITERAL(1), LITERAL(2)), WHERE(EQ(ident_x, LITERAL(0))));
const auto x_sym = symbol_table.CreateSymbol("x", true);
any->identifier_->MapTo(x_sym);
ident_x->MapTo(x_sym);
auto value = Eval(any);
ASSERT_TRUE(value.IsBool());
EXPECT_FALSE(value.ValueBool());
}
TEST_F(ExpressionEvaluatorTest, FunctionAnyNullList) {
AstStorage storage;
auto *any = ANY("x", LITERAL(storage::PropertyValue()), WHERE(LITERAL(true)));
const auto x_sym = symbol_table.CreateSymbol("x", true);
any->identifier_->MapTo(x_sym);
auto value = Eval(any);
EXPECT_TRUE(value.IsNull());
}
TEST_F(ExpressionEvaluatorTest, FunctionAnyWhereWrongType) {
AstStorage storage;
auto *any = ANY("x", LIST(LITERAL(1)), WHERE(LITERAL(2)));
const auto x_sym = symbol_table.CreateSymbol("x", true);
any->identifier_->MapTo(x_sym);
EXPECT_THROW(Eval(any), QueryRuntimeException);
}
TEST_F(ExpressionEvaluatorTest, FunctionReduce) { TEST_F(ExpressionEvaluatorTest, FunctionReduce) {
AstStorage storage; AstStorage storage;
auto *ident_sum = IDENT("sum"); auto *ident_sum = IDENT("sum");