Generate symbols for BFS

Reviewers: florijan, mislav.bradac

Reviewed By: mislav.bradac

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D608
This commit is contained in:
Teon Banek 2017-07-30 12:19:44 +02:00
parent 0aa399bd91
commit 4b72118894
3 changed files with 129 additions and 17 deletions

View File

@ -270,23 +270,7 @@ bool SymbolGenerator::PostVisit(Aggregation &) {
bool SymbolGenerator::PreVisit(All &all) { bool SymbolGenerator::PreVisit(All &all) {
all.list_expression_->Accept(*this); all.list_expression_->Accept(*this);
// Bind the new symbol after visiting the list expression. Keep the old symbol VisitWithIdentifiers(*all.where_, {all.identifier_});
// so it can be restored.
std::experimental::optional<Symbol> prev_symbol;
auto prev_symbol_it = scope_.symbols.find(all.identifier_->name_);
if (prev_symbol_it != scope_.symbols.end()) {
prev_symbol = prev_symbol_it->second;
}
symbol_table_[*all.identifier_] = CreateSymbol(all.identifier_->name_, true);
// Visit Where with the new symbol bound.
all.where_->Accept(*this);
// Restore the old symbol or just remove the newly bound if there was no
// symbol before.
if (prev_symbol) {
scope_.symbols[all.identifier_->name_] = *prev_symbol;
} else {
scope_.symbols.erase(all.identifier_->name_);
}
return false; return false;
} }
@ -379,6 +363,60 @@ bool SymbolGenerator::PostVisit(EdgeAtom &) {
return true; return true;
} }
void SymbolGenerator::VisitWithIdentifiers(
Tree &tree, const std::vector<Identifier *> &identifiers) {
std::vector<std::pair<std::experimental::optional<Symbol>, Identifier *>>
prev_symbols;
// Collect previous symbols if they exist.
for (const auto &identifier : identifiers) {
std::experimental::optional<Symbol> prev_symbol;
auto prev_symbol_it = scope_.symbols.find(identifier->name_);
if (prev_symbol_it != scope_.symbols.end()) {
prev_symbol = prev_symbol_it->second;
}
symbol_table_[*identifier] = CreateSymbol(identifier->name_, true);
prev_symbols.emplace_back(prev_symbol, identifier);
}
// Visit the tree with the new symbols bound.
tree.Accept(*this);
// Restore back to previous symbols.
for (const auto &prev : prev_symbols) {
const auto &prev_symbol = prev.first;
const auto &identifier = prev.second;
if (prev_symbol) {
scope_.symbols[identifier->name_] = *prev_symbol;
} else {
scope_.symbols.erase(identifier->name_);
}
}
}
bool SymbolGenerator::PreVisit(BreadthFirstAtom &bf_atom) {
scope_.visiting_edge = &bf_atom;
if (scope_.in_create || scope_.in_merge) {
throw SemanticException("BFS cannot be used to create edges.");
}
// Visiting BFS filter and max_depth expressions is not a pattern.
scope_.in_pattern = false;
bf_atom.max_depth_->Accept(*this);
VisitWithIdentifiers(
*bf_atom.filter_expression_,
{bf_atom.traversed_edge_identifier_, bf_atom.next_node_identifier_});
scope_.in_pattern = true;
// XXX: Make BFS symbol be EdgeList.
bf_atom.has_range_ = true;
scope_.in_pattern_identifier = true;
bf_atom.identifier_->Accept(*this);
scope_.in_pattern_identifier = false;
bf_atom.has_range_ = false;
return false;
}
bool SymbolGenerator::PostVisit(BreadthFirstAtom &bf_atom) {
scope_.visiting_edge = nullptr;
return true;
}
bool SymbolGenerator::HasSymbol(const std::string &name) { bool SymbolGenerator::HasSymbol(const std::string &name) {
return scope_.symbols.find(name) != scope_.symbols.end(); return scope_.symbols.find(name) != scope_.symbols.end();
} }

View File

@ -53,6 +53,8 @@ class SymbolGenerator : public HierarchicalTreeVisitor {
bool PostVisit(NodeAtom &) override; bool PostVisit(NodeAtom &) override;
bool PreVisit(EdgeAtom &) override; bool PreVisit(EdgeAtom &) override;
bool PostVisit(EdgeAtom &) override; bool PostVisit(EdgeAtom &) override;
bool PreVisit(BreadthFirstAtom &) override;
bool PostVisit(BreadthFirstAtom &) override;
private: private:
// Scope stores the state of where we are when visiting the AST and a map of // Scope stores the state of where we are when visiting the AST and a map of
@ -108,6 +110,8 @@ class SymbolGenerator : public HierarchicalTreeVisitor {
void VisitReturnBody(ReturnBody &body, Where *where = nullptr); void VisitReturnBody(ReturnBody &body, Where *where = nullptr);
void VisitWithIdentifiers(Tree &, const std::vector<Identifier *> &);
SymbolTable &symbol_table_; SymbolTable &symbol_table_;
Scope scope_; Scope scope_;
}; };

View File

@ -1010,4 +1010,74 @@ TEST(TestSymbolGenerator, WithReturnAll) {
EXPECT_NE(symbol_table.at(*all->identifier_), symbol_table.at(*ret_as_x)); EXPECT_NE(symbol_table.at(*all->identifier_), symbol_table.at(*ret_as_x));
} }
TEST(TestSymbolGenerator, MatchBfsReturn) {
// Test MATCH (n) -bfs[r](r, n | r.prop, n.prop)-> (m) RETURN r AS r
Dbms dbms;
auto dba = dbms.active();
auto prop = dba->property("prop");
AstTreeStorage storage;
auto *node_n = NODE("n");
auto *r_prop = PROPERTY_LOOKUP("r", prop);
auto *n_prop = PROPERTY_LOOKUP("n", prop);
auto *bfs =
storage.Create<BreadthFirstAtom>(IDENT("r"), EdgeAtom::Direction::OUT,
IDENT("r"), IDENT("n"), r_prop, n_prop);
auto *ret_r = IDENT("r");
auto *query =
QUERY(MATCH(PATTERN(node_n, bfs, NODE("m"))), RETURN(ret_r, AS("r")));
SymbolTable symbol_table;
SymbolGenerator symbol_generator(symbol_table);
query->Accept(symbol_generator);
// Symbols for `n`, `[r]`, `r|`, `n|`, `m` and `AS r`.
EXPECT_EQ(symbol_table.max_position(), 6);
EXPECT_EQ(symbol_table.at(*ret_r), symbol_table.at(*bfs->identifier_));
EXPECT_NE(symbol_table.at(*ret_r),
symbol_table.at(*bfs->traversed_edge_identifier_));
EXPECT_EQ(symbol_table.at(*bfs->traversed_edge_identifier_),
symbol_table.at(*r_prop->expression_));
EXPECT_NE(symbol_table.at(*node_n->identifier_),
symbol_table.at(*bfs->next_node_identifier_));
EXPECT_EQ(symbol_table.at(*node_n->identifier_),
symbol_table.at(*n_prop->expression_));
}
TEST(TestSymbolGenerator, MatchBfsUsesEdgeSymbolError) {
// Test MATCH (n) -bfs[r](e, n | r, 10)-> (m) RETURN r
AstTreeStorage storage;
auto *bfs = storage.Create<BreadthFirstAtom>(
IDENT("r"), EdgeAtom::Direction::OUT, IDENT("e"), IDENT("n"), IDENT("r"),
LITERAL(10));
auto *query = QUERY(MATCH(PATTERN(NODE("n"), bfs, NODE("m"))), RETURN("r"));
SymbolTable symbol_table;
SymbolGenerator symbol_generator(symbol_table);
EXPECT_THROW(query->Accept(symbol_generator), UnboundVariableError);
}
TEST(TestSymbolGenerator, MatchBfsUsesPreviousOuterSymbol) {
// Test MATCH (a) -bfs[r](e, n | a, 10)-> (m) RETURN r
AstTreeStorage storage;
auto *node_a = NODE("a");
auto *bfs = storage.Create<BreadthFirstAtom>(
IDENT("r"), EdgeAtom::Direction::OUT, IDENT("e"), IDENT("n"), IDENT("a"),
LITERAL(10));
auto *query = QUERY(MATCH(PATTERN(node_a, bfs, NODE("m"))), RETURN("r"));
SymbolTable symbol_table;
SymbolGenerator symbol_generator(symbol_table);
query->Accept(symbol_generator);
EXPECT_EQ(symbol_table.at(*node_a->identifier_),
symbol_table.at(*bfs->filter_expression_));
}
TEST(TestSymbolGenerator, MatchBfsUsesLaterSymbolError) {
// Test MATCH (n) -bfs[r](e, n | m, 10)-> (m) RETURN r
AstTreeStorage storage;
auto *bfs = storage.Create<BreadthFirstAtom>(
IDENT("r"), EdgeAtom::Direction::OUT, IDENT("e"), IDENT("n"), IDENT("m"),
LITERAL(10));
auto *query = QUERY(MATCH(PATTERN(NODE("n"), bfs, NODE("m"))), RETURN("r"));
SymbolTable symbol_table;
SymbolGenerator symbol_generator(symbol_table);
EXPECT_THROW(query->Accept(symbol_generator), UnboundVariableError);
}
} // namespace } // namespace