Check if aggregation is used in right clause

Reviewers: florijan, mislav.bradac

Reviewed By: mislav.bradac

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D274
This commit is contained in:
Teon Banek 2017-04-13 12:41:07 +02:00
parent d9e02d624d
commit 15d5328957
3 changed files with 42 additions and 32 deletions

View File

@ -31,15 +31,24 @@ auto SymbolGenerator::GetOrCreateSymbol(const std::string &name,
void SymbolGenerator::Visit(Create &create) { scope_.in_create = true; } void SymbolGenerator::Visit(Create &create) { scope_.in_create = true; }
void SymbolGenerator::PostVisit(Create &create) { scope_.in_create = false; } void SymbolGenerator::PostVisit(Create &create) { scope_.in_create = false; }
void SymbolGenerator::Visit(Return &ret) {
scope_.in_return = true;
}
void SymbolGenerator::PostVisit(Return &ret) { void SymbolGenerator::PostVisit(Return &ret) {
for (auto &named_expr : ret.named_expressions_) { for (auto &named_expr : ret.named_expressions_) {
// Named expressions establish bindings for expressions which come after // Named expressions establish bindings for expressions which come after
// return, but not for the expressions contained inside. // return, but not for the expressions contained inside.
symbol_table_[*named_expr] = CreateSymbol(named_expr->name_); symbol_table_[*named_expr] = CreateSymbol(named_expr->name_);
} }
scope_.in_return = false;
} }
void SymbolGenerator::SetWithSymbols(With &with) { bool SymbolGenerator::PreVisit(With &with) {
scope_.in_with = true;
for (auto &expr : with.named_expressions_) {
expr->Accept(*this);
}
scope_.in_with = false;
// WITH clause removes declarations of all the previous variables and declares // WITH clause removes declarations of all the previous variables and declares
// only those established through named expressions. New declarations must not // only those established through named expressions. New declarations must not
// be visible inside named expressions themselves. // be visible inside named expressions themselves.
@ -47,26 +56,8 @@ void SymbolGenerator::SetWithSymbols(With &with) {
for (auto &named_expr : with.named_expressions_) { for (auto &named_expr : with.named_expressions_) {
symbol_table_[*named_expr] = CreateSymbol(named_expr->name_); symbol_table_[*named_expr] = CreateSymbol(named_expr->name_);
} }
} if (with.where_) with.where_->Accept(*this);
return false; // We handled the traversal ourselves.
void SymbolGenerator::Visit(With &with) {
scope_.with = &with;
}
void SymbolGenerator::Visit(Where &where) {
if (scope_.with) {
// New symbols must be visible in WHERE clause, so this must be done here
// and not in PostVisit(With&).
SetWithSymbols(*scope_.with);
}
}
void SymbolGenerator::PostVisit(With &with) {
if (!with.where_) {
// This wasn't done when visiting Where, so do it here.
SetWithSymbols(with);
}
scope_.with = nullptr;
} }
// Expressions // Expressions
@ -104,12 +95,20 @@ void SymbolGenerator::Visit(Identifier &ident) {
} }
void SymbolGenerator::Visit(Aggregation &aggr) { void SymbolGenerator::Visit(Aggregation &aggr) {
// Create a virtual symbol for aggregation result. // Check if the aggregation can be used in this context. This check should
symbol_table_[aggr] = symbol_table_.CreateSymbol(""); // probably move to a separate phase, which checks if the query is well
// formed.
if (!scope_.in_return && !scope_.in_with) {
throw SemanticException(
"Aggregation functions are only allowed in WITH and RETURN");
}
if (scope_.in_aggregation) { if (scope_.in_aggregation) {
throw SemanticException( throw SemanticException(
"Using aggregate functions inside aggregate functions is not allowed"); "Using aggregation functions inside aggregation functions is not "
"allowed");
} }
// Create a virtual symbol for aggregation result.
symbol_table_[aggr] = symbol_table_.CreateSymbol("");
scope_.in_aggregation = true; scope_.in_aggregation = true;
} }

View File

@ -20,16 +20,16 @@ class SymbolGenerator : public TreeVisitorBase {
public: public:
SymbolGenerator(SymbolTable &symbol_table) : symbol_table_(symbol_table) {} SymbolGenerator(SymbolTable &symbol_table) : symbol_table_(symbol_table) {}
using TreeVisitorBase::PreVisit;
using TreeVisitorBase::Visit; using TreeVisitorBase::Visit;
using TreeVisitorBase::PostVisit; using TreeVisitorBase::PostVisit;
// Clauses // Clauses
void Visit(Create &) override; void Visit(Create &) override;
void PostVisit(Create &) override; void PostVisit(Create &) override;
void Visit(Return &) override;
void PostVisit(Return &) override; void PostVisit(Return &) override;
void Visit(With &) override; bool PreVisit(With &) override;
void PostVisit(With &) override;
void Visit(Where &) override;
// Expressions // Expressions
void Visit(Identifier &) override; void Visit(Identifier &) override;
@ -59,8 +59,8 @@ class SymbolGenerator : public TreeVisitorBase {
bool in_edge_atom{false}; bool in_edge_atom{false};
bool in_property_map{false}; bool in_property_map{false};
bool in_aggregation{false}; bool in_aggregation{false};
// Pointer to With clause if we are inside it, otherwise nullptr. bool in_return{false};
With *with{nullptr}; bool in_with{false};
std::map<std::string, Symbol> symbols; std::map<std::string, Symbol> symbols;
}; };
@ -76,9 +76,6 @@ class SymbolGenerator : public TreeVisitorBase {
auto GetOrCreateSymbol(const std::string &name, auto GetOrCreateSymbol(const std::string &name,
Symbol::Type type = Symbol::Type::Any); Symbol::Type type = Symbol::Type::Any);
// Clear old symbol bindings and establish new from WITH clause.
void SetWithSymbols(With &with);
SymbolTable &symbol_table_; SymbolTable &symbol_table_;
Scope scope_; Scope scope_;
}; };

View File

@ -446,6 +446,20 @@ TEST(TestSymbolGenerator, NestedAggregation) {
EXPECT_THROW(query->Accept(symbol_generator), SemanticException); EXPECT_THROW(query->Accept(symbol_generator), SemanticException);
} }
TEST(TestSymbolGenerator, WrongAggregationContext) {
// Test MATCH (n) WITH n.prop AS prop WHERE SUM(prop) < 42
Dbms dbms;
auto dba = dbms.active();
auto prop = dba->property("prop");
AstTreeStorage storage;
auto query = QUERY(MATCH(PATTERN(NODE("n"))),
WITH(PROPERTY_LOOKUP("n", prop), AS("prop")),
WHERE(LESS(SUM(IDENT("prop")), LITERAL(42))));
SymbolTable symbol_table;
SymbolGenerator symbol_generator(symbol_table);
EXPECT_THROW(query->Accept(symbol_generator), SemanticException);
}
TEST(TestSymbolGenerator, MatchPropCreateNodeProp) { TEST(TestSymbolGenerator, MatchPropCreateNodeProp) {
// Test MATCH (n) CREATE (m {prop: n.prop}) // Test MATCH (n) CREATE (m {prop: n.prop})
Dbms dbms; Dbms dbms;