Check if aggregation is used in right clause
Reviewers: florijan, mislav.bradac Reviewed By: mislav.bradac Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D274
This commit is contained in:
parent
d9e02d624d
commit
15d5328957
@ -31,15 +31,24 @@ auto SymbolGenerator::GetOrCreateSymbol(const std::string &name,
|
|||||||
void SymbolGenerator::Visit(Create &create) { scope_.in_create = true; }
|
void SymbolGenerator::Visit(Create &create) { scope_.in_create = true; }
|
||||||
void SymbolGenerator::PostVisit(Create &create) { scope_.in_create = false; }
|
void SymbolGenerator::PostVisit(Create &create) { scope_.in_create = false; }
|
||||||
|
|
||||||
|
void SymbolGenerator::Visit(Return &ret) {
|
||||||
|
scope_.in_return = true;
|
||||||
|
}
|
||||||
void SymbolGenerator::PostVisit(Return &ret) {
|
void SymbolGenerator::PostVisit(Return &ret) {
|
||||||
for (auto &named_expr : ret.named_expressions_) {
|
for (auto &named_expr : ret.named_expressions_) {
|
||||||
// Named expressions establish bindings for expressions which come after
|
// Named expressions establish bindings for expressions which come after
|
||||||
// return, but not for the expressions contained inside.
|
// return, but not for the expressions contained inside.
|
||||||
symbol_table_[*named_expr] = CreateSymbol(named_expr->name_);
|
symbol_table_[*named_expr] = CreateSymbol(named_expr->name_);
|
||||||
}
|
}
|
||||||
|
scope_.in_return = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
void SymbolGenerator::SetWithSymbols(With &with) {
|
bool SymbolGenerator::PreVisit(With &with) {
|
||||||
|
scope_.in_with = true;
|
||||||
|
for (auto &expr : with.named_expressions_) {
|
||||||
|
expr->Accept(*this);
|
||||||
|
}
|
||||||
|
scope_.in_with = false;
|
||||||
// WITH clause removes declarations of all the previous variables and declares
|
// WITH clause removes declarations of all the previous variables and declares
|
||||||
// only those established through named expressions. New declarations must not
|
// only those established through named expressions. New declarations must not
|
||||||
// be visible inside named expressions themselves.
|
// be visible inside named expressions themselves.
|
||||||
@ -47,26 +56,8 @@ void SymbolGenerator::SetWithSymbols(With &with) {
|
|||||||
for (auto &named_expr : with.named_expressions_) {
|
for (auto &named_expr : with.named_expressions_) {
|
||||||
symbol_table_[*named_expr] = CreateSymbol(named_expr->name_);
|
symbol_table_[*named_expr] = CreateSymbol(named_expr->name_);
|
||||||
}
|
}
|
||||||
}
|
if (with.where_) with.where_->Accept(*this);
|
||||||
|
return false; // We handled the traversal ourselves.
|
||||||
void SymbolGenerator::Visit(With &with) {
|
|
||||||
scope_.with = &with;
|
|
||||||
}
|
|
||||||
|
|
||||||
void SymbolGenerator::Visit(Where &where) {
|
|
||||||
if (scope_.with) {
|
|
||||||
// New symbols must be visible in WHERE clause, so this must be done here
|
|
||||||
// and not in PostVisit(With&).
|
|
||||||
SetWithSymbols(*scope_.with);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void SymbolGenerator::PostVisit(With &with) {
|
|
||||||
if (!with.where_) {
|
|
||||||
// This wasn't done when visiting Where, so do it here.
|
|
||||||
SetWithSymbols(with);
|
|
||||||
}
|
|
||||||
scope_.with = nullptr;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Expressions
|
// Expressions
|
||||||
@ -104,12 +95,20 @@ void SymbolGenerator::Visit(Identifier &ident) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void SymbolGenerator::Visit(Aggregation &aggr) {
|
void SymbolGenerator::Visit(Aggregation &aggr) {
|
||||||
// Create a virtual symbol for aggregation result.
|
// Check if the aggregation can be used in this context. This check should
|
||||||
symbol_table_[aggr] = symbol_table_.CreateSymbol("");
|
// probably move to a separate phase, which checks if the query is well
|
||||||
|
// formed.
|
||||||
|
if (!scope_.in_return && !scope_.in_with) {
|
||||||
|
throw SemanticException(
|
||||||
|
"Aggregation functions are only allowed in WITH and RETURN");
|
||||||
|
}
|
||||||
if (scope_.in_aggregation) {
|
if (scope_.in_aggregation) {
|
||||||
throw SemanticException(
|
throw SemanticException(
|
||||||
"Using aggregate functions inside aggregate functions is not allowed");
|
"Using aggregation functions inside aggregation functions is not "
|
||||||
|
"allowed");
|
||||||
}
|
}
|
||||||
|
// Create a virtual symbol for aggregation result.
|
||||||
|
symbol_table_[aggr] = symbol_table_.CreateSymbol("");
|
||||||
scope_.in_aggregation = true;
|
scope_.in_aggregation = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -20,16 +20,16 @@ class SymbolGenerator : public TreeVisitorBase {
|
|||||||
public:
|
public:
|
||||||
SymbolGenerator(SymbolTable &symbol_table) : symbol_table_(symbol_table) {}
|
SymbolGenerator(SymbolTable &symbol_table) : symbol_table_(symbol_table) {}
|
||||||
|
|
||||||
|
using TreeVisitorBase::PreVisit;
|
||||||
using TreeVisitorBase::Visit;
|
using TreeVisitorBase::Visit;
|
||||||
using TreeVisitorBase::PostVisit;
|
using TreeVisitorBase::PostVisit;
|
||||||
|
|
||||||
// Clauses
|
// Clauses
|
||||||
void Visit(Create &) override;
|
void Visit(Create &) override;
|
||||||
void PostVisit(Create &) override;
|
void PostVisit(Create &) override;
|
||||||
|
void Visit(Return &) override;
|
||||||
void PostVisit(Return &) override;
|
void PostVisit(Return &) override;
|
||||||
void Visit(With &) override;
|
bool PreVisit(With &) override;
|
||||||
void PostVisit(With &) override;
|
|
||||||
void Visit(Where &) override;
|
|
||||||
|
|
||||||
// Expressions
|
// Expressions
|
||||||
void Visit(Identifier &) override;
|
void Visit(Identifier &) override;
|
||||||
@ -59,8 +59,8 @@ class SymbolGenerator : public TreeVisitorBase {
|
|||||||
bool in_edge_atom{false};
|
bool in_edge_atom{false};
|
||||||
bool in_property_map{false};
|
bool in_property_map{false};
|
||||||
bool in_aggregation{false};
|
bool in_aggregation{false};
|
||||||
// Pointer to With clause if we are inside it, otherwise nullptr.
|
bool in_return{false};
|
||||||
With *with{nullptr};
|
bool in_with{false};
|
||||||
std::map<std::string, Symbol> symbols;
|
std::map<std::string, Symbol> symbols;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -76,9 +76,6 @@ class SymbolGenerator : public TreeVisitorBase {
|
|||||||
auto GetOrCreateSymbol(const std::string &name,
|
auto GetOrCreateSymbol(const std::string &name,
|
||||||
Symbol::Type type = Symbol::Type::Any);
|
Symbol::Type type = Symbol::Type::Any);
|
||||||
|
|
||||||
// Clear old symbol bindings and establish new from WITH clause.
|
|
||||||
void SetWithSymbols(With &with);
|
|
||||||
|
|
||||||
SymbolTable &symbol_table_;
|
SymbolTable &symbol_table_;
|
||||||
Scope scope_;
|
Scope scope_;
|
||||||
};
|
};
|
||||||
|
@ -446,6 +446,20 @@ TEST(TestSymbolGenerator, NestedAggregation) {
|
|||||||
EXPECT_THROW(query->Accept(symbol_generator), SemanticException);
|
EXPECT_THROW(query->Accept(symbol_generator), SemanticException);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(TestSymbolGenerator, WrongAggregationContext) {
|
||||||
|
// Test MATCH (n) WITH n.prop AS prop WHERE SUM(prop) < 42
|
||||||
|
Dbms dbms;
|
||||||
|
auto dba = dbms.active();
|
||||||
|
auto prop = dba->property("prop");
|
||||||
|
AstTreeStorage storage;
|
||||||
|
auto query = QUERY(MATCH(PATTERN(NODE("n"))),
|
||||||
|
WITH(PROPERTY_LOOKUP("n", prop), AS("prop")),
|
||||||
|
WHERE(LESS(SUM(IDENT("prop")), LITERAL(42))));
|
||||||
|
SymbolTable symbol_table;
|
||||||
|
SymbolGenerator symbol_generator(symbol_table);
|
||||||
|
EXPECT_THROW(query->Accept(symbol_generator), SemanticException);
|
||||||
|
}
|
||||||
|
|
||||||
TEST(TestSymbolGenerator, MatchPropCreateNodeProp) {
|
TEST(TestSymbolGenerator, MatchPropCreateNodeProp) {
|
||||||
// Test MATCH (n) CREATE (m {prop: n.prop})
|
// Test MATCH (n) CREATE (m {prop: n.prop})
|
||||||
Dbms dbms;
|
Dbms dbms;
|
||||||
|
Loading…
Reference in New Issue
Block a user