Plan Merge operator

Summary:
Check symbols in Merge.
Support MERGE macro in query tests.
Test SymbolGenerator with MERGE.
Test planning Merge.

Reviewers: florijan, mislav.bradac

Reviewed By: florijan

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D317
This commit is contained in:
Teon Banek 2017-04-26 13:49:41 +02:00
parent 0fa4555cad
commit 8fa574026e
9 changed files with 274 additions and 66 deletions

View File

@ -111,6 +111,9 @@ bool SymbolGenerator::PreVisit(With &with) {
void SymbolGenerator::Visit(Where &) { scope_.in_where = true; } void SymbolGenerator::Visit(Where &) { scope_.in_where = true; }
void SymbolGenerator::PostVisit(Where &) { scope_.in_where = false; } void SymbolGenerator::PostVisit(Where &) { scope_.in_where = false; }
void SymbolGenerator::Visit(Merge &) { scope_.in_merge = true; }
void SymbolGenerator::PostVisit(Merge &) { scope_.in_merge = false; }
// Expressions // Expressions
void SymbolGenerator::Visit(Identifier &ident) { void SymbolGenerator::Visit(Identifier &ident) {
@ -178,7 +181,7 @@ void SymbolGenerator::PostVisit(Aggregation &aggr) {
void SymbolGenerator::Visit(Pattern &pattern) { void SymbolGenerator::Visit(Pattern &pattern) {
scope_.in_pattern = true; scope_.in_pattern = true;
if (scope_.in_create && pattern.atoms_.size() == 1U) { if ((scope_.in_create || scope_.in_merge) && pattern.atoms_.size() == 1U) {
debug_assert(dynamic_cast<NodeAtom *>(pattern.atoms_[0]), debug_assert(dynamic_cast<NodeAtom *>(pattern.atoms_[0]),
"Expected a single NodeAtom in Pattern"); "Expected a single NodeAtom in Pattern");
scope_.in_create_node = true; scope_.in_create_node = true;
@ -213,14 +216,15 @@ void SymbolGenerator::PostVisit(NodeAtom &node_atom) {
void SymbolGenerator::Visit(EdgeAtom &edge_atom) { void SymbolGenerator::Visit(EdgeAtom &edge_atom) {
scope_.in_edge_atom = true; scope_.in_edge_atom = true;
if (scope_.in_create) { if (scope_.in_create || scope_.in_merge) {
scope_.in_create_edge = true; scope_.in_create_edge = true;
if (edge_atom.edge_types_.size() != 1U) { if (edge_atom.edge_types_.size() != 1U) {
throw SemanticException( throw SemanticException(
"A single relationship type must be specified " "A single relationship type must be specified "
"when creating an edge."); "when creating an edge.");
} }
if (edge_atom.direction_ == EdgeAtom::Direction::BOTH) { if (scope_.in_create && // Merge allows bidirectionality
edge_atom.direction_ == EdgeAtom::Direction::BOTH) {
throw SemanticException( throw SemanticException(
"Bidirectional relationship are not supported " "Bidirectional relationship are not supported "
"when creating an edge"); "when creating an edge");

View File

@ -31,6 +31,8 @@ class SymbolGenerator : public TreeVisitorBase {
bool PreVisit(With &) override; bool PreVisit(With &) override;
void Visit(Where &) override; void Visit(Where &) override;
void PostVisit(Where &) override; void PostVisit(Where &) override;
void Visit(Merge &) override;
void PostVisit(Merge &) override;
// Expressions // Expressions
void Visit(Identifier &) override; void Visit(Identifier &) override;
@ -50,11 +52,14 @@ class SymbolGenerator : public TreeVisitorBase {
// names to symbols. // names to symbols.
struct Scope { struct Scope {
bool in_pattern{false}; bool in_pattern{false};
bool in_merge{false};
bool in_create{false}; bool in_create{false};
// in_create_node is true if we are creating *only* a node. Therefore, it // in_create_node is true if we are creating or merging *only* a node.
// is *not* equivalent to in_create && in_node_atom. // Therefore, it is *not* equivalent to (in_create || in_merge) &&
// in_node_atom.
bool in_create_node{false}; bool in_create_node{false};
// True if creating an edge; shortcut for in_create && in_edge_atom. // True if creating an edge;
// shortcut for (in_create || in_merge) && in_edge_atom.
bool in_create_edge{false}; bool in_create_edge{false};
bool in_node_atom{false}; bool in_node_atom{false};
bool in_edge_atom{false}; bool in_edge_atom{false};

View File

@ -77,7 +77,8 @@ void Interpret(const std::string &query, GraphDbAccessor &db_accessor,
dynamic_cast<plan::SetLabels *>(logical_plan.get()) || dynamic_cast<plan::SetLabels *>(logical_plan.get()) ||
dynamic_cast<plan::RemoveProperty *>(logical_plan.get()) || dynamic_cast<plan::RemoveProperty *>(logical_plan.get()) ||
dynamic_cast<plan::RemoveLabels *>(logical_plan.get()) || dynamic_cast<plan::RemoveLabels *>(logical_plan.get()) ||
dynamic_cast<plan::Delete *>(logical_plan.get())) { dynamic_cast<plan::Delete *>(logical_plan.get()) ||
dynamic_cast<plan::Merge *>(logical_plan.get())) {
stream.Header(header); stream.Header(header);
auto cursor = logical_plan->MakeCursor(db_accessor); auto cursor = logical_plan->MakeCursor(db_accessor);
while (cursor->Pull(frame, symbol_table)) continue; while (cursor->Pull(frame, symbol_table)) continue;

View File

@ -85,7 +85,7 @@ CreateExpand::CreateExpand(const NodeAtom *node_atom, const EdgeAtom *edge_atom,
Symbol input_symbol, bool existing_node) Symbol input_symbol, bool existing_node)
: node_atom_(node_atom), : node_atom_(node_atom),
edge_atom_(edge_atom), edge_atom_(edge_atom),
input_(input), input_(input ? input : std::make_shared<Once>()),
input_symbol_(input_symbol), input_symbol_(input_symbol),
existing_node_(existing_node) {} existing_node_(existing_node) {}
@ -223,7 +223,7 @@ Expand::Expand(const NodeAtom *node_atom, const EdgeAtom *edge_atom,
GraphView graph_view) GraphView graph_view)
: node_atom_(node_atom), : node_atom_(node_atom),
edge_atom_(edge_atom), edge_atom_(edge_atom),
input_(input), input_(input ? input : std::make_shared<Once>()),
input_symbol_(input_symbol), input_symbol_(input_symbol),
existing_node_(existing_node), existing_node_(existing_node),
existing_edge_(existing_edge), existing_edge_(existing_edge),

View File

@ -185,7 +185,7 @@ class CreateExpand : public LogicalOperator {
* @param node_atom @c NodeAtom at the end of the edge. Used to create a node, * @param node_atom @c NodeAtom at the end of the edge. Used to create a node,
* unless it refers to an existing one. * unless it refers to an existing one.
* @param edge_atom @c EdgeAtom with information for the edge to be created. * @param edge_atom @c EdgeAtom with information for the edge to be created.
* @param input Required. Previous @c LogicalOperator which will be pulled. * @param input Optional. Previous @c LogicalOperator which will be pulled.
* For each successful @c Cursor::Pull, this operator will create an * For each successful @c Cursor::Pull, this operator will create an
* expansion. * expansion.
* @param input_symbol @c Symbol for the node at the start of the edge. * @param input_symbol @c Symbol for the node at the start of the edge.
@ -321,7 +321,7 @@ class Expand : public LogicalOperator {
* identifier is used, labels and properties are ignored. * identifier is used, labels and properties are ignored.
* @param edge_atom Describes the edge to be expanded. Identifier * @param edge_atom Describes the edge to be expanded. Identifier
* and direction are used, edge type and properties are ignored. * and direction are used, edge type and properties are ignored.
* @param input LogicalOperation that preceeds this one. * @param input Optional LogicalOperator that preceeds this one.
* @param input_symbol Symbol that points to a VertexAccessor * @param input_symbol Symbol that points to a VertexAccessor
* in the Frame that expansion should emanate from. * in the Frame that expansion should emanate from.
* @param existing_node If or not the node to be expanded is already * @param existing_node If or not the node to be expanded is already
@ -1165,6 +1165,10 @@ class Merge : public LogicalOperator {
void Accept(LogicalOperatorVisitor &visitor) override; void Accept(LogicalOperatorVisitor &visitor) override;
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override; std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
auto input() const { return input_; }
auto merge_match() const { return merge_match_; }
auto merge_create() const { return merge_create_; }
private: private:
const std::shared_ptr<LogicalOperator> input_; const std::shared_ptr<LogicalOperator> input_;
const std::shared_ptr<LogicalOperator> merge_match_; const std::shared_ptr<LogicalOperator> merge_match_;

View File

@ -7,8 +7,7 @@
#include "query/frontend/ast/ast.hpp" #include "query/frontend/ast/ast.hpp"
#include "utils/exceptions.hpp" #include "utils/exceptions.hpp"
namespace query { namespace query::plan {
namespace plan {
namespace { namespace {
@ -105,13 +104,15 @@ auto GenCreate(Create &create, LogicalOperator *input_op,
auto GenMatchForPattern(Pattern &pattern, LogicalOperator *input_op, auto GenMatchForPattern(Pattern &pattern, LogicalOperator *input_op,
const SymbolTable &symbol_table, const SymbolTable &symbol_table,
std::unordered_set<int> &bound_symbols, std::unordered_set<int> &bound_symbols,
std::vector<Symbol> &edge_symbols) { std::vector<Symbol> &edge_symbols,
GraphView graph_view = GraphView::OLD) {
auto base = [&](NodeAtom *node) { auto base = [&](NodeAtom *node) {
LogicalOperator *last_op = input_op; LogicalOperator *last_op = input_op;
// If the first atom binds a symbol, we generate a ScanAll which writes it. // If the first atom binds a symbol, we generate a ScanAll which writes it.
// Otherwise, someone else generates it (e.g. a previous ScanAll). // Otherwise, someone else generates it (e.g. a previous ScanAll).
if (BindSymbol(bound_symbols, symbol_table.at(*node->identifier_))) { if (BindSymbol(bound_symbols, symbol_table.at(*node->identifier_))) {
last_op = new ScanAll(node, std::shared_ptr<LogicalOperator>(last_op)); last_op = new ScanAll(node, std::shared_ptr<LogicalOperator>(last_op),
graph_view);
} }
// Even though we may skip generating ScanAll, we still want to add a filter // Even though we may skip generating ScanAll, we still want to add a filter
// in case this atom adds more labels/properties for filtering. // in case this atom adds more labels/properties for filtering.
@ -126,20 +127,21 @@ auto GenMatchForPattern(Pattern &pattern, LogicalOperator *input_op,
// Store the symbol from the first node as the input to Expand. // Store the symbol from the first node as the input to Expand.
const auto &input_symbol = symbol_table.at(*prev_node->identifier_); const auto &input_symbol = symbol_table.at(*prev_node->identifier_);
// If the expand symbols were already bound, then we need to indicate // If the expand symbols were already bound, then we need to indicate
// this as a cycle. The Expand will then check whether the pattern holds // that they exist. The Expand will then check whether the pattern holds
// instead of writing the expansion to symbols. // instead of writing the expansion to symbols.
auto node_cycle = false; auto existing_node = false;
auto edge_cycle = false; auto existing_edge = false;
if (!BindSymbol(bound_symbols, symbol_table.at(*node->identifier_))) { if (!BindSymbol(bound_symbols, symbol_table.at(*node->identifier_))) {
node_cycle = true; existing_node = true;
} }
const auto &edge_symbol = symbol_table.at(*edge->identifier_); const auto &edge_symbol = symbol_table.at(*edge->identifier_);
if (!BindSymbol(bound_symbols, edge_symbol)) { if (!BindSymbol(bound_symbols, edge_symbol)) {
edge_cycle = true; existing_edge = true;
} }
last_op = new Expand(node, edge, std::shared_ptr<LogicalOperator>(last_op), last_op =
input_symbol, node_cycle, edge_cycle); new Expand(node, edge, std::shared_ptr<LogicalOperator>(last_op),
if (!edge_cycle) { input_symbol, existing_node, existing_edge, graph_view);
if (!existing_edge) {
// Ensure Cyphermorphism (different edge symbols always map to different // Ensure Cyphermorphism (different edge symbols always map to different
// edges). // edges).
if (!edge_symbols.empty()) { if (!edge_symbols.empty()) {
@ -455,6 +457,32 @@ LogicalOperator *HandleWriteClause(Clause *clause, LogicalOperator *input_op,
return nullptr; return nullptr;
} }
auto GenMerge(query::Merge &merge, LogicalOperator *input_op,
const SymbolTable &symbol_table,
std::unordered_set<int> &bound_symbols) {
// Copy the bound symbol set, because we don't want to use the updated version
// when generating the create part.
std::unordered_set<int> bound_symbols_copy(bound_symbols);
std::vector<Symbol> edge_symbols;
auto on_match =
GenMatchForPattern(*merge.pattern_, nullptr, symbol_table,
bound_symbols_copy, edge_symbols, GraphView::NEW);
// Use the original bound_symbols, so we fill it with new symbols.
auto on_create = GenCreateForPattern(*merge.pattern_, nullptr, symbol_table,
bound_symbols);
for (auto &set : merge.on_create_) {
on_create = HandleWriteClause(set, on_create, symbol_table, bound_symbols);
debug_assert(on_create, "Expected SET in MERGE ... ON CREATE");
}
for (auto &set : merge.on_match_) {
on_match = HandleWriteClause(set, on_match, symbol_table, bound_symbols);
debug_assert(on_match, "Expected SET in MERGE ... ON MATCH");
}
return new plan::Merge(std::shared_ptr<LogicalOperator>(input_op),
std::shared_ptr<LogicalOperator>(on_match),
std::shared_ptr<LogicalOperator>(on_create));
}
} // namespace } // namespace
std::unique_ptr<LogicalOperator> MakeLogicalPlan( std::unique_ptr<LogicalOperator> MakeLogicalPlan(
@ -464,7 +492,7 @@ std::unique_ptr<LogicalOperator> MakeLogicalPlan(
// or write it. E.g. `MATCH (n) -[r]- (n)` would bind (and write) the first // or write it. E.g. `MATCH (n) -[r]- (n)` would bind (and write) the first
// `n`, but the latter `n` would only read the already written information. // `n`, but the latter `n` would only read the already written information.
std::unordered_set<int> bound_symbols; std::unordered_set<int> bound_symbols;
// Set to true if a query command performs a writes to the database. // Set to true if a query command writes to the database.
bool is_write = false; bool is_write = false;
LogicalOperator *input_op = nullptr; LogicalOperator *input_op = nullptr;
for (auto &clause : query.clauses_) { for (auto &clause : query.clauses_) {
@ -473,6 +501,11 @@ std::unique_ptr<LogicalOperator> MakeLogicalPlan(
input_op = GenMatch(*match, input_op, symbol_table, bound_symbols); input_op = GenMatch(*match, input_op, symbol_table, bound_symbols);
} else if (auto *ret = dynamic_cast<Return *>(clause)) { } else if (auto *ret = dynamic_cast<Return *>(clause)) {
input_op = GenReturn(*ret, input_op, symbol_table, is_write); input_op = GenReturn(*ret, input_op, symbol_table, is_write);
} else if (auto *merge = dynamic_cast<query::Merge *>(clause)) {
input_op = GenMerge(*merge, input_op, symbol_table, bound_symbols);
// Treat MERGE clause as write, because we do not know if it will create
// anything.
is_write = true;
} else if (auto *with = dynamic_cast<query::With *>(clause)) { } else if (auto *with = dynamic_cast<query::With *>(clause)) {
input_op = input_op =
GenWith(*with, input_op, symbol_table, is_write, bound_symbols); GenWith(*with, input_op, symbol_table, is_write, bound_symbols);
@ -489,5 +522,4 @@ std::unique_ptr<LogicalOperator> MakeLogicalPlan(
return std::unique_ptr<LogicalOperator>(input_op); return std::unique_ptr<LogicalOperator>(input_op);
} }
} // namespace plan } // namespace query::plan
} // namespace query

View File

@ -28,8 +28,8 @@ namespace query {
namespace test_common { namespace test_common {
// Custom types for ORDER BY, SKIP and LIMIT and expressions, so that they can // Custom types for ORDER BY, SKIP, LIMIT, ON MATCH and ON CREATE expressions,
// be used to resolve function calls. // so that they can be used to resolve function calls.
struct OrderBy { struct OrderBy {
std::vector<std::pair<Ordering, Expression *>> expressions; std::vector<std::pair<Ordering, Expression *>> expressions;
}; };
@ -39,6 +39,12 @@ struct Skip {
struct Limit { struct Limit {
Expression *expression = nullptr; Expression *expression = nullptr;
}; };
struct OnMatch {
std::vector<Clause *> set;
};
struct OnCreate {
std::vector<Clause *> set;
};
// Helper functions for filling the OrderBy with expressions. // Helper functions for filling the OrderBy with expressions.
auto FillOrderBy(OrderBy &order_by, Expression *expression, auto FillOrderBy(OrderBy &order_by, Expression *expression,
@ -301,6 +307,26 @@ auto GetRemove(AstTreeStorage &storage, const std::string &name,
return storage.Create<RemoveLabels>(storage.Create<Identifier>(name), labels); return storage.Create<RemoveLabels>(storage.Create<Identifier>(name), labels);
} }
///
/// Create a Merge clause for given Pattern with optional OnMatch and OnCreate
/// parts.
///
auto GetMerge(AstTreeStorage &storage, Pattern *pattern,
OnCreate on_create = OnCreate{}) {
auto *merge = storage.Create<query::Merge>();
merge->pattern_ = pattern;
merge->on_create_ = on_create.set;
return merge;
}
auto GetMerge(AstTreeStorage &storage, Pattern *pattern, OnMatch on_match,
OnCreate on_create = OnCreate{}) {
auto *merge = storage.Create<query::Merge>();
merge->pattern_ = pattern;
merge->on_match_ = on_match.set;
merge->on_create_ = on_create.set;
return merge;
}
} // namespace test_common } // namespace test_common
} // namespace query } // namespace query
@ -346,6 +372,15 @@ auto GetRemove(AstTreeStorage &storage, const std::string &name,
query::test_common::GetDelete(storage, {__VA_ARGS__}, true) query::test_common::GetDelete(storage, {__VA_ARGS__}, true)
#define SET(...) query::test_common::GetSet(storage, __VA_ARGS__) #define SET(...) query::test_common::GetSet(storage, __VA_ARGS__)
#define REMOVE(...) query::test_common::GetRemove(storage, __VA_ARGS__) #define REMOVE(...) query::test_common::GetRemove(storage, __VA_ARGS__)
#define MERGE(...) query::test_common::GetMerge(storage, __VA_ARGS__)
#define ON_MATCH(...) \
query::test_common::OnMatch { \
std::vector<query::Clause *> { __VA_ARGS__ } \
}
#define ON_CREATE(...) \
query::test_common::OnCreate { \
std::vector<query::Clause *> { __VA_ARGS__ } \
}
#define QUERY(...) query::test_common::GetQuery(storage, __VA_ARGS__) #define QUERY(...) query::test_common::GetQuery(storage, __VA_ARGS__)
// Various operators // Various operators
#define ADD(expr1, expr2) \ #define ADD(expr1, expr2) \

View File

@ -29,6 +29,57 @@ class BaseOpChecker {
virtual void CheckOp(LogicalOperator &, const SymbolTable &) = 0; virtual void CheckOp(LogicalOperator &, const SymbolTable &) = 0;
}; };
class PlanChecker : public LogicalOperatorVisitor {
public:
using LogicalOperatorVisitor::PreVisit;
using LogicalOperatorVisitor::Visit;
using LogicalOperatorVisitor::PostVisit;
PlanChecker(const std::list<BaseOpChecker *> &checkers,
const SymbolTable &symbol_table)
: checkers_(checkers), symbol_table_(symbol_table) {}
void Visit(CreateNode &op) override { CheckOp(op); }
void Visit(CreateExpand &op) override { CheckOp(op); }
void Visit(Delete &op) override { CheckOp(op); }
void Visit(ScanAll &op) override { CheckOp(op); }
void Visit(Expand &op) override { CheckOp(op); }
void Visit(NodeFilter &op) override { CheckOp(op); }
void Visit(EdgeFilter &op) override { CheckOp(op); }
void Visit(Filter &op) override { CheckOp(op); }
void Visit(Produce &op) override { CheckOp(op); }
void Visit(SetProperty &op) override { CheckOp(op); }
void Visit(SetProperties &op) override { CheckOp(op); }
void Visit(SetLabels &op) override { CheckOp(op); }
void Visit(RemoveProperty &op) override { CheckOp(op); }
void Visit(RemoveLabels &op) override { CheckOp(op); }
void Visit(ExpandUniquenessFilter<VertexAccessor> &op) override {
CheckOp(op);
}
void Visit(ExpandUniquenessFilter<EdgeAccessor> &op) override { CheckOp(op); }
void Visit(Accumulate &op) override { CheckOp(op); }
void Visit(Aggregate &op) override { CheckOp(op); }
void Visit(Skip &op) override { CheckOp(op); }
void Visit(Limit &op) override { CheckOp(op); }
void Visit(OrderBy &op) override { CheckOp(op); }
bool PreVisit(Merge &op) override {
CheckOp(op);
op.input()->Accept(*this);
return false;
}
std::list<BaseOpChecker *> checkers_;
private:
void CheckOp(LogicalOperator &op) {
ASSERT_FALSE(checkers_.empty());
checkers_.back()->CheckOp(op, symbol_table_);
checkers_.pop_back();
}
const SymbolTable &symbol_table_;
};
template <class TOp> template <class TOp>
class OpChecker : public BaseOpChecker { class OpChecker : public BaseOpChecker {
public: public:
@ -103,49 +154,22 @@ class ExpectAggregate : public OpChecker<Aggregate> {
const std::unordered_set<query::Expression *> group_by_; const std::unordered_set<query::Expression *> group_by_;
}; };
class PlanChecker : public LogicalOperatorVisitor { class ExpectMerge : public OpChecker<Merge> {
public: public:
using LogicalOperatorVisitor::Visit; ExpectMerge(const std::list<BaseOpChecker *> &on_match,
using LogicalOperatorVisitor::PostVisit; const std::list<BaseOpChecker *> &on_create)
: on_match_(on_match), on_create_(on_create) {}
PlanChecker(const std::list<BaseOpChecker *> &checkers, void ExpectOp(Merge &merge, const SymbolTable &symbol_table) override {
const SymbolTable &symbol_table) PlanChecker check_match(on_match_, symbol_table);
: checkers_(checkers), symbol_table_(symbol_table) {} merge.merge_match()->Accept(check_match);
PlanChecker check_create(on_create_, symbol_table);
void Visit(CreateNode &op) override { CheckOp(op); } merge.merge_create()->Accept(check_create);
void Visit(CreateExpand &op) override { CheckOp(op); }
void Visit(Delete &op) override { CheckOp(op); }
void Visit(ScanAll &op) override { CheckOp(op); }
void Visit(Expand &op) override { CheckOp(op); }
void Visit(NodeFilter &op) override { CheckOp(op); }
void Visit(EdgeFilter &op) override { CheckOp(op); }
void Visit(Filter &op) override { CheckOp(op); }
void Visit(Produce &op) override { CheckOp(op); }
void Visit(SetProperty &op) override { CheckOp(op); }
void Visit(SetProperties &op) override { CheckOp(op); }
void Visit(SetLabels &op) override { CheckOp(op); }
void Visit(RemoveProperty &op) override { CheckOp(op); }
void Visit(RemoveLabels &op) override { CheckOp(op); }
void Visit(ExpandUniquenessFilter<VertexAccessor> &op) override {
CheckOp(op);
} }
void Visit(ExpandUniquenessFilter<EdgeAccessor> &op) override { CheckOp(op); }
void Visit(Accumulate &op) override { CheckOp(op); }
void Visit(Aggregate &op) override { CheckOp(op); }
void Visit(Skip &op) override { CheckOp(op); }
void Visit(Limit &op) override { CheckOp(op); }
void Visit(OrderBy &op) override { CheckOp(op); }
std::list<BaseOpChecker *> checkers_;
private: private:
void CheckOp(LogicalOperator &op) { const std::list<BaseOpChecker *> &on_match_;
ASSERT_FALSE(checkers_.empty()); const std::list<BaseOpChecker *> &on_create_;
checkers_.back()->CheckOp(op, symbol_table_);
checkers_.pop_back();
}
const SymbolTable &symbol_table_;
}; };
auto MakeSymbolTable(query::Query &query) { auto MakeSymbolTable(query::Query &query) {
@ -571,4 +595,36 @@ TEST(TestLogicalPlanner, ReturnAddSumCountOrderBy) {
CheckPlan(*query, aggr, ExpectProduce(), ExpectOrderBy()); CheckPlan(*query, aggr, ExpectProduce(), ExpectOrderBy());
} }
TEST(TestLogicalPlanner, MatchMerge) {
// Test MATCH (n) MERGE (n) -[r :r]- (m)
// ON MATCH SET n.prop = 42 ON CREATE SET m = n
// RETURN n AS n
Dbms dbms;
auto dba = dbms.active();
auto r_type = dba->edge_type("r");
auto prop = dba->property("prop");
AstTreeStorage storage;
auto ident_n = IDENT("n");
auto query =
QUERY(MATCH(PATTERN(NODE("n"))),
MERGE(PATTERN(NODE("n"), EDGE("r", r_type), NODE("m")),
ON_MATCH(SET(PROPERTY_LOOKUP("n", prop), LITERAL(42))),
ON_CREATE(SET("m", IDENT("n")))),
RETURN(ident_n, AS("n")));
std::list<BaseOpChecker *> on_match{
new ExpectExpand(), new ExpectEdgeFilter(), new ExpectSetProperty()};
std::list<BaseOpChecker *> on_create{new ExpectCreateExpand(),
new ExpectSetProperties()};
auto symbol_table = MakeSymbolTable(*query);
// We expect Accumulate after Merge, because it is considered as a write.
auto acc = ExpectAccumulate({symbol_table.at(*ident_n)});
auto plan = MakeLogicalPlan(*query, symbol_table);
CheckPlan(*plan, symbol_table, ExpectScanAll(),
ExpectMerge(on_match, on_create), acc, ExpectProduce());
for (auto &op : on_match) delete op;
on_match.clear();
for (auto &op : on_create) delete op;
on_create.clear();
}
} // namespace } // namespace

View File

@ -649,4 +649,75 @@ TEST(TestSymbolGenerator, OrderBy) {
} }
} }
TEST(TestSymbolGenerator, Merge) {
// Test MATCH (n) MERGE (n)
{
AstTreeStorage storage;
auto query = QUERY(MATCH(PATTERN(NODE("n"))), MERGE(PATTERN(NODE("n"))));
SymbolTable symbol_table;
SymbolGenerator symbol_generator(symbol_table);
EXPECT_THROW(query->Accept(symbol_generator), RedeclareVariableError);
}
// Test MATCH (n) -[r]- (m) MERGE (a) -[r :rel]- (b)
{
Dbms dbms;
auto dba = dbms.active();
auto rel = dba->edge_type("rel");
AstTreeStorage storage;
auto query = QUERY(MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m"))),
MERGE(PATTERN(NODE("a"), EDGE("r", rel), NODE("b"))));
SymbolTable symbol_table;
SymbolGenerator symbol_generator(symbol_table);
EXPECT_THROW(query->Accept(symbol_generator), RedeclareVariableError);
}
// Test MERGE (a) -[r]- (b)
{
AstTreeStorage storage;
auto query = QUERY(MERGE(PATTERN(NODE("a"), EDGE("r"), NODE("b"))));
SymbolTable symbol_table;
SymbolGenerator symbol_generator(symbol_table);
// Edge must have a type, since it doesn't we raise.
EXPECT_THROW(query->Accept(symbol_generator), SemanticException);
}
// Test MATCH (n) MERGE (n) -[r :rel]- (m) ON MATCH SET n.prop = 42
// ON CREATE SET m.prop = 42 RETURN r AS r
{
Dbms dbms;
auto dba = dbms.active();
auto rel = dba->edge_type("rel");
auto prop = dba->property("prop");
AstTreeStorage storage;
auto match_n = NODE("n");
auto merge_n = NODE("n");
auto edge_r = EDGE("r", rel);
auto node_m = NODE("m");
auto n_prop = PROPERTY_LOOKUP("n", prop);
auto m_prop = PROPERTY_LOOKUP("m", prop);
auto ident_r = IDENT("r");
auto as_r = AS("r");
auto query = QUERY(MATCH(PATTERN(match_n)),
MERGE(PATTERN(merge_n, edge_r, node_m),
ON_MATCH(SET(n_prop, LITERAL(42))),
ON_CREATE(SET(m_prop, LITERAL(42)))),
RETURN(ident_r, as_r));
SymbolTable symbol_table;
SymbolGenerator symbol_generator(symbol_table);
query->Accept(symbol_generator);
// Symbols for: `n`, `r`, `m` and `AS r`.
EXPECT_EQ(symbol_table.max_position(), 4);
auto n = symbol_table.at(*match_n->identifier_);
EXPECT_EQ(n, symbol_table.at(*merge_n->identifier_));
EXPECT_EQ(n, symbol_table.at(*n_prop->expression_));
auto r = symbol_table.at(*edge_r->identifier_);
EXPECT_NE(r, n);
EXPECT_EQ(r, symbol_table.at(*ident_r));
EXPECT_NE(r, symbol_table.at(*as_r));
auto m = symbol_table.at(*node_m->identifier_);
EXPECT_NE(m, n);
EXPECT_NE(m, r);
EXPECT_NE(m, symbol_table.at(*as_r));
EXPECT_EQ(m, symbol_table.at(*m_prop->expression_));
}
}
} }