From 8da6ce67c04fe76e7a7fd16a6cc2de013c112bef Mon Sep 17 00:00:00 2001 From: Teon Banek <theongugl@gmail.com> Date: Wed, 22 Mar 2017 13:09:38 +0100 Subject: [PATCH] Add planning CreateExpand operator Summary: Add planning CreateExpand operator. This is quite similar to planning Expand, but I wouldn't abstract the duplicated parts yet. Also, raise semantic error if creating bidirectional edges Reviewers: buda, mislav.bradac, florijan Reviewed By: mislav.bradac Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D155 --- src/query/frontend/logical/planner.cpp | 190 ++++++++++-------- .../frontend/semantic/symbol_generator.hpp | 8 +- tests/unit/query_planner.cpp | 50 +++++ tests/unit/query_semantic.cpp | 33 ++- 4 files changed, 189 insertions(+), 92 deletions(-) diff --git a/src/query/frontend/logical/planner.cpp b/src/query/frontend/logical/planner.cpp index 1e811e330..9d9c77e60 100644 --- a/src/query/frontend/logical/planner.cpp +++ b/src/query/frontend/logical/planner.cpp @@ -9,41 +9,69 @@ namespace query { namespace { -static LogicalOperator *GenCreate( - Create& create, std::shared_ptr<LogicalOperator> input_op) -{ - if (input_op) { - // TODO: Support clauses before CREATE, e.g. `MATCH (n) CREATE (m)` - throw NotYetImplemented(); - } - if (create.patterns_.size() != 1) { - // TODO: Support creating multiple patterns, e.g. `CREATE (n), (m)` - throw NotYetImplemented(); - } - auto &pattern = create.patterns_[0]; - if (pattern->atoms_.size() != 1) { - // TODO: Support creating edges. - throw NotYetImplemented(); - } - auto *node_atom = dynamic_cast<NodeAtom*>(pattern->atoms_[0]); - debug_assert(node_atom, "First pattern atom is not a node"); - return new CreateOp(node_atom); -} - // Returns false if the symbol was already bound, otherwise binds it and // returns true. -bool BindSymbol(std::unordered_set<int> &bound_symbols, const Symbol &symbol) -{ +bool BindSymbol(std::unordered_set<int> &bound_symbols, const Symbol &symbol) { auto insertion = bound_symbols.insert(symbol.position_); return insertion.second; } -LogicalOperator *GenMatch( - Match& match, - std::shared_ptr<LogicalOperator> input_op, - const SymbolTable &symbol_table, - std::unordered_set<int> &bound_symbols) -{ +LogicalOperator *GenCreateForPattern(Pattern &pattern, + LogicalOperator *input_op, + const SymbolTable &symbol_table, + std::unordered_set<int> bound_symbols) { + auto atoms_it = pattern.atoms_.begin(); + auto last_node = dynamic_cast<NodeAtom *>(*atoms_it++); + debug_assert(last_node, "First pattern atom is not a node"); + auto last_op = input_op; + if (BindSymbol(bound_symbols, symbol_table.at(*last_node->identifier_))) { + // TODO: Pass last_op when CreateOp gets support for it. This will + // support e.g. `MATCH (n) CREATE (m)` and `CREATE (n), (m)`. + if (last_op) { + throw NotYetImplemented(); + } + last_op = new CreateOp(last_node); + } + // Remaining atoms need to follow sequentially as (EdgeAtom, NodeAtom)* + while (atoms_it != pattern.atoms_.end()) { + auto edge = dynamic_cast<EdgeAtom *>(*atoms_it++); + debug_assert(edge, "Expected an edge atom in pattern."); + debug_assert(atoms_it != pattern.atoms_.end(), + "Edge atom should not end the pattern."); + // Store the symbol from the first node as the input to CreateExpand. + auto input_symbol = symbol_table.at(*last_node->identifier_); + last_node = dynamic_cast<NodeAtom *>(*atoms_it++); + debug_assert(last_node, "Expected a node atom in pattern."); + // If the expand node was already bound, then we need to indicate this, + // so that CreateExpand only creates an edge. + bool node_existing = false; + if (!BindSymbol(bound_symbols, symbol_table.at(*last_node->identifier_))) { + node_existing = true; + } + if (!BindSymbol(bound_symbols, symbol_table.at(*edge->identifier_))) { + permanent_fail("Symbols used for created edges cannot be redeclared."); + } + last_op = new CreateExpand(last_node, edge, + std::shared_ptr<LogicalOperator>(last_op), + input_symbol, node_existing); + } + return last_op; +} + +LogicalOperator *GenCreate(Create &create, LogicalOperator *input_op, + const SymbolTable &symbol_table, + std::unordered_set<int> bound_symbols) { + auto last_op = input_op; + for (auto pattern : create.patterns_) { + last_op = + GenCreateForPattern(*pattern, last_op, symbol_table, bound_symbols); + } + return last_op; +} + +LogicalOperator *GenMatch(Match &match, LogicalOperator *input_op, + const SymbolTable &symbol_table, + std::unordered_set<int> &bound_symbols) { if (input_op) { // TODO: Support clauses before match. throw NotYetImplemented(); @@ -55,7 +83,7 @@ LogicalOperator *GenMatch( auto &pattern = match.patterns_[0]; debug_assert(!pattern->atoms_.empty(), "Missing atoms in pattern"); auto atoms_it = pattern->atoms_.begin(); - auto last_node = dynamic_cast<NodeAtom*>(*atoms_it++); + auto last_node = dynamic_cast<NodeAtom *>(*atoms_it++); debug_assert(last_node, "First pattern atom is not a node"); // First atom always binds a symbol, and we don't care if it already existed, // because we create a ScanAll which writes that symbol. This may need to @@ -63,67 +91,60 @@ LogicalOperator *GenMatch( BindSymbol(bound_symbols, symbol_table.at(*last_node->identifier_)); LogicalOperator *last_op = new ScanAll(last_node); if (!last_node->labels_.empty() || !last_node->properties_.empty()) { - last_op = new NodeFilter(std::shared_ptr<LogicalOperator>(last_op), - symbol_table.at(*last_node->identifier_), - last_node); + last_op = + new NodeFilter(std::shared_ptr<LogicalOperator>(last_op), + symbol_table.at(*last_node->identifier_), last_node); } - EdgeAtom *last_edge = nullptr; // Remaining atoms need to follow sequentially as (EdgeAtom, NodeAtom)* - for ( ; atoms_it != pattern->atoms_.end(); ++atoms_it) { - if (last_edge) { - // Store the symbol from the first node as the input to Expand. - auto input_symbol = symbol_table.at(*last_node->identifier_); - last_node = dynamic_cast<NodeAtom*>(*atoms_it); - debug_assert(last_node, "Expected a node atom in pattern."); - // If the expand symbols were already bound, then we need to indicate - // this as a cycle. The Expand will then check whether the pattern holds - // instead of writing the expansion to symbols. - auto node_cycle = false; - auto edge_cycle = false; - if (!BindSymbol(bound_symbols, symbol_table.at(*last_node->identifier_))) { - node_cycle = true; - } - if (!BindSymbol(bound_symbols, symbol_table.at(*last_edge->identifier_))) { - edge_cycle = true; - } - last_op = new Expand(last_node, last_edge, - std::shared_ptr<LogicalOperator>(last_op), - input_symbol, node_cycle, edge_cycle); - if (!last_edge->edge_types_.empty()) { - last_op = new EdgeFilter(std::shared_ptr<LogicalOperator>(last_op), - symbol_table.at(*last_edge->identifier_), - last_edge); - } - if (!last_node->labels_.empty() || !last_node->properties_.empty()) { - last_op = new NodeFilter(std::shared_ptr<LogicalOperator>(last_op), - symbol_table.at(*last_node->identifier_), - last_node); - } - // Don't forget to clear the edge, because we expect the next - // (EdgeAtom, NodeAtom) sequence. - last_edge = nullptr; - } else { - last_edge = dynamic_cast<EdgeAtom*>(*atoms_it); - debug_assert(last_edge, "Expected an edge atom in pattern."); + while (atoms_it != pattern->atoms_.end()) { + auto edge = dynamic_cast<EdgeAtom *>(*atoms_it++); + debug_assert(edge, "Expected an edge atom in pattern."); + debug_assert(atoms_it != pattern->atoms_.end(), + "Edge atom should not end the pattern."); + // Store the symbol from the first node as the input to Expand. + auto input_symbol = symbol_table.at(*last_node->identifier_); + last_node = dynamic_cast<NodeAtom *>(*atoms_it++); + debug_assert(last_node, "Expected a node atom in pattern."); + // If the expand symbols were already bound, then we need to indicate + // this as a cycle. The Expand will then check whether the pattern holds + // instead of writing the expansion to symbols. + auto node_cycle = false; + auto edge_cycle = false; + if (!BindSymbol(bound_symbols, symbol_table.at(*last_node->identifier_))) { + node_cycle = true; + } + if (!BindSymbol(bound_symbols, symbol_table.at(*edge->identifier_))) { + edge_cycle = true; + } + last_op = + new Expand(last_node, edge, std::shared_ptr<LogicalOperator>(last_op), + input_symbol, node_cycle, edge_cycle); + if (!edge->edge_types_.empty() || !edge->properties_.empty()) { + last_op = new EdgeFilter(std::shared_ptr<LogicalOperator>(last_op), + symbol_table.at(*edge->identifier_), edge); + } + if (!last_node->labels_.empty() || !last_node->properties_.empty()) { + last_op = + new NodeFilter(std::shared_ptr<LogicalOperator>(last_op), + symbol_table.at(*last_node->identifier_), last_node); } } - debug_assert(!last_edge, "Edge atom should not end the pattern."); return last_op; } -Produce *GenReturn(Return& ret, std::shared_ptr<LogicalOperator> input_op) -{ +Produce *GenReturn(Return &ret, LogicalOperator *input_op) { if (!input_op) { // TODO: Support standalone RETURN clause (e.g. RETURN 2) throw NotYetImplemented(); } - return new Produce(input_op, ret.named_expressions_); -} + return new Produce(std::shared_ptr<LogicalOperator>(input_op), + ret.named_expressions_); } +} // namespace + std::unique_ptr<LogicalOperator> MakeLogicalPlan( - Query& query, const SymbolTable &symbol_table) -{ + Query &query, const SymbolTable &symbol_table) { // TODO: Extract functions and state into a class with methods. Possibly a // visitor or similar to avoid all those dynamic casts. LogicalOperator *input_op = nullptr; @@ -134,13 +155,12 @@ std::unique_ptr<LogicalOperator> MakeLogicalPlan( std::unordered_set<int> bound_symbols; for (auto &clause : query.clauses_) { auto *clause_ptr = clause; - if (auto *match = dynamic_cast<Match*>(clause_ptr)) { - input_op = GenMatch(*match, std::shared_ptr<LogicalOperator>(input_op), - symbol_table, bound_symbols); - } else if (auto *ret = dynamic_cast<Return*>(clause_ptr)) { - input_op = GenReturn(*ret, std::shared_ptr<LogicalOperator>(input_op)); - } else if (auto *create = dynamic_cast<Create*>(clause_ptr)) { - input_op = GenCreate(*create, std::shared_ptr<LogicalOperator>(input_op)); + if (auto *match = dynamic_cast<Match *>(clause_ptr)) { + input_op = GenMatch(*match, input_op, symbol_table, bound_symbols); + } else if (auto *ret = dynamic_cast<Return *>(clause_ptr)) { + input_op = GenReturn(*ret, input_op); + } else if (auto *create = dynamic_cast<Create *>(clause_ptr)) { + input_op = GenCreate(*create, input_op, symbol_table, bound_symbols); } else { throw NotYetImplemented(); } @@ -148,4 +168,4 @@ std::unique_ptr<LogicalOperator> MakeLogicalPlan( return std::unique_ptr<LogicalOperator>(input_op); } -} +} // namespace query diff --git a/src/query/frontend/semantic/symbol_generator.hpp b/src/query/frontend/semantic/symbol_generator.hpp index 9c44557f6..bf0558e34 100644 --- a/src/query/frontend/semantic/symbol_generator.hpp +++ b/src/query/frontend/semantic/symbol_generator.hpp @@ -101,6 +101,10 @@ class SymbolGenerator : public TreeVisitorBase { throw SemanticException("A single relationship type must be specified " "when creating an edge."); } + if (edge_atom.direction_ == EdgeAtom::Direction::BOTH) { + throw SemanticException("Bidirectional relationship are not supported " + "when creating an edge"); + } } } void PostVisit(EdgeAtom &edge_atom) override { @@ -112,9 +116,7 @@ class SymbolGenerator : public TreeVisitorBase { // A variable stores the associated symbol and its type. struct Variable { // This is similar to TypedValue::Type, but this has `Any` type. - enum class Type : unsigned { - Any, Vertex, Edge, Path - }; + enum class Type { Any, Vertex, Edge, Path }; Symbol symbol; Type type{Type::Any}; diff --git a/tests/unit/query_planner.cpp b/tests/unit/query_planner.cpp index 4c846c33e..9ecf56381 100644 --- a/tests/unit/query_planner.cpp +++ b/tests/unit/query_planner.cpp @@ -21,6 +21,7 @@ class PlanChecker : public LogicalOperatorVisitor { PlanChecker(std::list<size_t> types) : types_(types) {} void Visit(CreateOp &op) override { AssertType(op); } + void Visit(CreateExpand &op) override { AssertType(op); } void Visit(ScanAll &op) override { AssertType(op); } void Visit(Expand &op) override { AssertType(op); } void Visit(NodeFilter &op) override { AssertType(op); } @@ -102,6 +103,55 @@ TEST(TestLogicalPlanner, CreateNodeReturn) { plan->Accept(plan_checker); } +TEST(TestLogicalPlanner, CreateExpand) { + // Test CREATE (n) -[r :rel1]-> (m) + AstTreeStorage storage; + auto create = storage.Create<Create>(); + auto pattern = GetPattern(storage, {"n", "r", "m"}); + create->patterns_.emplace_back(pattern); + auto edge_atom = dynamic_cast<EdgeAtom*>(pattern->atoms_[1]); + edge_atom->direction_ = EdgeAtom::Direction::RIGHT; + std::string relationship("relationship"); + edge_atom->edge_types_.emplace_back(&relationship); + auto query = storage.query(); + query->clauses_.emplace_back(create); + SymbolTable symbol_table; + SymbolGenerator symbol_generator(symbol_table); + query->Accept(symbol_generator); + auto plan = MakeLogicalPlan(*query, symbol_table); + std::list<size_t> expected_types; + expected_types.emplace_back(typeid(CreateOp).hash_code()); + expected_types.emplace_back(typeid(CreateExpand).hash_code()); + PlanChecker plan_checker(expected_types); + plan->Accept(plan_checker); +} + +TEST(TestLogicalPlanner, MatchCreateExpand) { + // Test MATCH (n) CREATE (n) -[r :rel1]-> (m) + AstTreeStorage storage; + auto match = storage.Create<Match>(); + match->patterns_.emplace_back(GetPattern(storage, {"n"})); + auto query = storage.query(); + query->clauses_.emplace_back(match); + auto create = storage.Create<Create>(); + auto pattern = GetPattern(storage, {"n", "r", "m"}); + create->patterns_.emplace_back(pattern); + auto edge_atom = dynamic_cast<EdgeAtom*>(pattern->atoms_[1]); + edge_atom->direction_ = EdgeAtom::Direction::RIGHT; + std::string relationship("relationship"); + edge_atom->edge_types_.emplace_back(&relationship); + query->clauses_.emplace_back(create); + SymbolTable symbol_table; + SymbolGenerator symbol_generator(symbol_table); + query->Accept(symbol_generator); + auto plan = MakeLogicalPlan(*query, symbol_table); + std::list<size_t> expected_types; + expected_types.emplace_back(typeid(ScanAll).hash_code()); + expected_types.emplace_back(typeid(CreateExpand).hash_code()); + PlanChecker plan_checker(expected_types); + plan->Accept(plan_checker); +} + TEST(TestLogicalPlanner, MatchLabeledNodes) { // Test MATCH (n :label) RETURN n AS n AstTreeStorage storage; diff --git a/tests/unit/query_semantic.cpp b/tests/unit/query_semantic.cpp index f6bbbc036..467f7cb26 100644 --- a/tests/unit/query_semantic.cpp +++ b/tests/unit/query_semantic.cpp @@ -320,7 +320,7 @@ TEST(TestSymbolGenerator, MatchCreateRedeclareEdge) { SymbolTable symbol_table; AstTreeStorage storage; // AST with redeclaring a match edge variable in create: - // MATCH (n) -[r]-> (m) CREATE (n) -[r] -> (l) + // MATCH (n) -[r]- (m) CREATE (n) -[r] -> (l) auto match = storage.Create<Match>(); match->patterns_.emplace_back(GetPattern(storage, {"n", "r", "m"})); auto query = storage.query(); @@ -329,6 +329,7 @@ TEST(TestSymbolGenerator, MatchCreateRedeclareEdge) { auto create = storage.Create<Create>(); auto pattern = GetPattern(storage, {"n", "r", "l"}); auto edge_atom = dynamic_cast<EdgeAtom*>(pattern->atoms_[1]); + edge_atom->direction_ = EdgeAtom::Direction::RIGHT; std::string relationship("relationship"); edge_atom->edge_types_.emplace_back(&relationship); create->patterns_.emplace_back(pattern); @@ -353,13 +354,16 @@ TEST(TestSymbolGenerator, MatchTypeMismatch) { TEST(TestSymbolGenerator, MatchCreateTypeMismatch) { AstTreeStorage storage; // Using an edge variable as a node causes a type mismatch. - // MATCH (n1) -[r1]- (n2) CREATE (r1) -[r2]- (n2) + // MATCH (n1) -[r1]- (n2) CREATE (r1) -[r2]-> (n2) auto match = storage.Create<Match>(); match->patterns_.emplace_back(GetPattern(storage, {"n1", "r1", "n2"})); auto query = storage.query(); query->clauses_.emplace_back(match); auto create = storage.Create<Create>(); - create->patterns_.emplace_back(GetPattern(storage, {"r1", "r2", "n2"})); + auto pattern = GetPattern(storage, {"r1", "r2", "n2"}); + create->patterns_.emplace_back(pattern); + auto edge_atom = dynamic_cast<EdgeAtom*>(pattern->atoms_[1]); + edge_atom->direction_ = EdgeAtom::Direction::RIGHT; query->clauses_.emplace_back(create); SymbolTable symbol_table; SymbolGenerator symbol_generator(symbol_table); @@ -369,9 +373,30 @@ TEST(TestSymbolGenerator, MatchCreateTypeMismatch) { TEST(TestSymbolGenerator, CreateMultipleEdgeType) { AstTreeStorage storage; // Multiple edge relationship are not allowed when creating edges. - // CREATE (n) -[r :rel1 | :rel2]- (m) + // CREATE (n) -[r :rel1 | :rel2]-> (m) auto pattern = GetPattern(storage, {"n", "r", "m"}); auto edge_atom = dynamic_cast<EdgeAtom*>(pattern->atoms_[1]); + edge_atom->direction_ = EdgeAtom::Direction::RIGHT; + std::string rel1("rel1"); + edge_atom->edge_types_.emplace_back(&rel1); + std::string rel2("rel2"); + edge_atom->edge_types_.emplace_back(&rel2); + auto create = storage.Create<Create>(); + create->patterns_.emplace_back(pattern); + auto query = storage.query(); + query->clauses_.emplace_back(create); + SymbolTable symbol_table; + SymbolGenerator symbol_generator(symbol_table); + EXPECT_THROW(query->Accept(symbol_generator), SemanticException); +} + +TEST(TestSymbolGenerator, CreateBidirectionalEdge) { + AstTreeStorage storage; + // Bidirectional relationships are not allowed when creating edges. + // CREATE (n) -[r :rel1]- (m) + auto pattern = GetPattern(storage, {"n", "r", "m"}); + auto edge_atom = dynamic_cast<EdgeAtom*>(pattern->atoms_[1]); + edge_atom->direction_ = EdgeAtom::Direction::BOTH; std::string rel1("rel1"); edge_atom->edge_types_.emplace_back(&rel1); std::string rel2("rel2");