Plan Merge operator

Summary:
Check symbols in Merge.
Support MERGE macro in query tests.
Test SymbolGenerator with MERGE.
Test planning Merge.

Reviewers: florijan, mislav.bradac

Reviewed By: florijan

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D317
This commit is contained in:
Teon Banek 2017-04-26 13:49:41 +02:00
parent 0fa4555cad
commit 8fa574026e
9 changed files with 274 additions and 66 deletions

View File

@ -111,6 +111,9 @@ bool SymbolGenerator::PreVisit(With &with) {
// WHERE handling: Visit raises scope_.in_where, PostVisit lowers it again.
void SymbolGenerator::Visit(Where &) {
  scope_.in_where = true;
}

void SymbolGenerator::PostVisit(Where &) {
  scope_.in_where = false;
}

// MERGE handling: Visit raises scope_.in_merge, PostVisit lowers it again.
// Pattern and edge atom visitors consult this flag alongside in_create.
void SymbolGenerator::Visit(Merge &) {
  scope_.in_merge = true;
}

void SymbolGenerator::PostVisit(Merge &) {
  scope_.in_merge = false;
}
// Expressions
void SymbolGenerator::Visit(Identifier &ident) {
@ -178,7 +181,7 @@ void SymbolGenerator::PostVisit(Aggregation &aggr) {
void SymbolGenerator::Visit(Pattern &pattern) {
scope_.in_pattern = true;
if (scope_.in_create && pattern.atoms_.size() == 1U) {
if ((scope_.in_create || scope_.in_merge) && pattern.atoms_.size() == 1U) {
debug_assert(dynamic_cast<NodeAtom *>(pattern.atoms_[0]),
"Expected a single NodeAtom in Pattern");
scope_.in_create_node = true;
@ -213,14 +216,15 @@ void SymbolGenerator::PostVisit(NodeAtom &node_atom) {
void SymbolGenerator::Visit(EdgeAtom &edge_atom) {
scope_.in_edge_atom = true;
if (scope_.in_create) {
if (scope_.in_create || scope_.in_merge) {
scope_.in_create_edge = true;
if (edge_atom.edge_types_.size() != 1U) {
throw SemanticException(
"A single relationship type must be specified "
"when creating an edge.");
}
if (edge_atom.direction_ == EdgeAtom::Direction::BOTH) {
if (scope_.in_create && // Merge allows bidirectionality
edge_atom.direction_ == EdgeAtom::Direction::BOTH) {
throw SemanticException(
"Bidirectional relationship are not supported "
"when creating an edge");

View File

@ -31,6 +31,8 @@ class SymbolGenerator : public TreeVisitorBase {
bool PreVisit(With &) override;
void Visit(Where &) override;
void PostVisit(Where &) override;
void Visit(Merge &) override;
void PostVisit(Merge &) override;
// Expressions
void Visit(Identifier &) override;
@ -50,11 +52,14 @@ class SymbolGenerator : public TreeVisitorBase {
// names to symbols.
struct Scope {
bool in_pattern{false};
bool in_merge{false};
bool in_create{false};
// in_create_node is true if we are creating *only* a node. Therefore, it
// is *not* equivalent to in_create && in_node_atom.
// in_create_node is true if we are creating or merging *only* a node.
// Therefore, it is *not* equivalent to (in_create || in_merge) &&
// in_node_atom.
bool in_create_node{false};
// True if creating an edge; shortcut for in_create && in_edge_atom.
// True if creating an edge;
// shortcut for (in_create || in_merge) && in_edge_atom.
bool in_create_edge{false};
bool in_node_atom{false};
bool in_edge_atom{false};

View File

@ -77,7 +77,8 @@ void Interpret(const std::string &query, GraphDbAccessor &db_accessor,
dynamic_cast<plan::SetLabels *>(logical_plan.get()) ||
dynamic_cast<plan::RemoveProperty *>(logical_plan.get()) ||
dynamic_cast<plan::RemoveLabels *>(logical_plan.get()) ||
dynamic_cast<plan::Delete *>(logical_plan.get())) {
dynamic_cast<plan::Delete *>(logical_plan.get()) ||
dynamic_cast<plan::Merge *>(logical_plan.get())) {
stream.Header(header);
auto cursor = logical_plan->MakeCursor(db_accessor);
while (cursor->Pull(frame, symbol_table)) continue;

View File

@ -85,7 +85,7 @@ CreateExpand::CreateExpand(const NodeAtom *node_atom, const EdgeAtom *edge_atom,
Symbol input_symbol, bool existing_node)
: node_atom_(node_atom),
edge_atom_(edge_atom),
input_(input),
input_(input ? input : std::make_shared<Once>()),
input_symbol_(input_symbol),
existing_node_(existing_node) {}
@ -223,7 +223,7 @@ Expand::Expand(const NodeAtom *node_atom, const EdgeAtom *edge_atom,
GraphView graph_view)
: node_atom_(node_atom),
edge_atom_(edge_atom),
input_(input),
input_(input ? input : std::make_shared<Once>()),
input_symbol_(input_symbol),
existing_node_(existing_node),
existing_edge_(existing_edge),

View File

@ -185,7 +185,7 @@ class CreateExpand : public LogicalOperator {
* @param node_atom @c NodeAtom at the end of the edge. Used to create a node,
* unless it refers to an existing one.
* @param edge_atom @c EdgeAtom with information for the edge to be created.
* @param input Required. Previous @c LogicalOperator which will be pulled.
* @param input Optional. Previous @c LogicalOperator which will be pulled.
* For each successful @c Cursor::Pull, this operator will create an
* expansion.
* @param input_symbol @c Symbol for the node at the start of the edge.
@ -321,7 +321,7 @@ class Expand : public LogicalOperator {
* identifier is used, labels and properties are ignored.
* @param edge_atom Describes the edge to be expanded. Identifier
* and direction are used, edge type and properties are ignored.
* @param input LogicalOperation that preceeds this one.
* @param input Optional LogicalOperator that precedes this one.
* @param input_symbol Symbol that points to a VertexAccessor
* in the Frame that expansion should emanate from.
* @param existing_node If or not the node to be expanded is already
@ -1165,6 +1165,10 @@ class Merge : public LogicalOperator {
void Accept(LogicalOperatorVisitor &visitor) override;
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
auto input() const { return input_; }
auto merge_match() const { return merge_match_; }
auto merge_create() const { return merge_create_; }
private:
const std::shared_ptr<LogicalOperator> input_;
const std::shared_ptr<LogicalOperator> merge_match_;

View File

@ -7,8 +7,7 @@
#include "query/frontend/ast/ast.hpp"
#include "utils/exceptions.hpp"
namespace query {
namespace plan {
namespace query::plan {
namespace {
@ -105,13 +104,15 @@ auto GenCreate(Create &create, LogicalOperator *input_op,
auto GenMatchForPattern(Pattern &pattern, LogicalOperator *input_op,
const SymbolTable &symbol_table,
std::unordered_set<int> &bound_symbols,
std::vector<Symbol> &edge_symbols) {
std::vector<Symbol> &edge_symbols,
GraphView graph_view = GraphView::OLD) {
auto base = [&](NodeAtom *node) {
LogicalOperator *last_op = input_op;
// If the first atom binds a symbol, we generate a ScanAll which writes it.
// Otherwise, someone else generates it (e.g. a previous ScanAll).
if (BindSymbol(bound_symbols, symbol_table.at(*node->identifier_))) {
last_op = new ScanAll(node, std::shared_ptr<LogicalOperator>(last_op));
last_op = new ScanAll(node, std::shared_ptr<LogicalOperator>(last_op),
graph_view);
}
// Even though we may skip generating ScanAll, we still want to add a filter
// in case this atom adds more labels/properties for filtering.
@ -126,20 +127,21 @@ auto GenMatchForPattern(Pattern &pattern, LogicalOperator *input_op,
// Store the symbol from the first node as the input to Expand.
const auto &input_symbol = symbol_table.at(*prev_node->identifier_);
// If the expand symbols were already bound, then we need to indicate
// this as a cycle. The Expand will then check whether the pattern holds
// that they exist. The Expand will then check whether the pattern holds
// instead of writing the expansion to symbols.
auto node_cycle = false;
auto edge_cycle = false;
auto existing_node = false;
auto existing_edge = false;
if (!BindSymbol(bound_symbols, symbol_table.at(*node->identifier_))) {
node_cycle = true;
existing_node = true;
}
const auto &edge_symbol = symbol_table.at(*edge->identifier_);
if (!BindSymbol(bound_symbols, edge_symbol)) {
edge_cycle = true;
existing_edge = true;
}
last_op = new Expand(node, edge, std::shared_ptr<LogicalOperator>(last_op),
input_symbol, node_cycle, edge_cycle);
if (!edge_cycle) {
last_op =
new Expand(node, edge, std::shared_ptr<LogicalOperator>(last_op),
input_symbol, existing_node, existing_edge, graph_view);
if (!existing_edge) {
// Ensure Cyphermorphism (different edge symbols always map to different
// edges).
if (!edge_symbols.empty()) {
@ -455,6 +457,32 @@ LogicalOperator *HandleWriteClause(Clause *clause, LogicalOperator *input_op,
return nullptr;
}
// Builds a plan::Merge operator for the given MERGE clause.
//
// Two sub-plans are generated from the same pattern, each starting without an
// input operator (nullptr is passed as their input):
//   * a match sub-plan, generated with GraphView::NEW;
//   * a create sub-plan.
// The ON MATCH / ON CREATE clauses are chained onto the corresponding sub-plan
// via HandleWriteClause. The returned operator is allocated with `new` and is
// handed input_op and both sub-plans wrapped in std::shared_ptr.
auto GenMerge(query::Merge &merge, LogicalOperator *input_op,
              const SymbolTable &symbol_table,
              std::unordered_set<int> &bound_symbols) {
  // The match sub-plan binds symbols in a private snapshot, so that the create
  // sub-plan below is generated against the bound symbols as they were before
  // matching.
  std::unordered_set<int> match_bound_symbols(bound_symbols);
  std::vector<Symbol> match_edge_symbols;
  auto match_branch = GenMatchForPattern(*merge.pattern_, nullptr, symbol_table,
                                         match_bound_symbols,
                                         match_edge_symbols, GraphView::NEW);
  // The create sub-plan works on the caller's set, which therefore ends up
  // containing the symbols introduced by the pattern.
  auto create_branch = GenCreateForPattern(*merge.pattern_, nullptr,
                                           symbol_table, bound_symbols);
  for (auto &create_clause : merge.on_create_) {
    create_branch = HandleWriteClause(create_clause, create_branch,
                                      symbol_table, bound_symbols);
    debug_assert(create_branch, "Expected SET in MERGE ... ON CREATE");
  }
  for (auto &match_clause : merge.on_match_) {
    match_branch = HandleWriteClause(match_clause, match_branch, symbol_table,
                                     bound_symbols);
    debug_assert(match_branch, "Expected SET in MERGE ... ON MATCH");
  }
  return new plan::Merge(std::shared_ptr<LogicalOperator>(input_op),
                         std::shared_ptr<LogicalOperator>(match_branch),
                         std::shared_ptr<LogicalOperator>(create_branch));
}
} // namespace
std::unique_ptr<LogicalOperator> MakeLogicalPlan(
@ -464,7 +492,7 @@ std::unique_ptr<LogicalOperator> MakeLogicalPlan(
// or write it. E.g. `MATCH (n) -[r]- (n)` would bind (and write) the first
// `n`, but the latter `n` would only read the already written information.
std::unordered_set<int> bound_symbols;
// Set to true if a query command performs a writes to the database.
// Set to true if a query command writes to the database.
bool is_write = false;
LogicalOperator *input_op = nullptr;
for (auto &clause : query.clauses_) {
@ -473,6 +501,11 @@ std::unique_ptr<LogicalOperator> MakeLogicalPlan(
input_op = GenMatch(*match, input_op, symbol_table, bound_symbols);
} else if (auto *ret = dynamic_cast<Return *>(clause)) {
input_op = GenReturn(*ret, input_op, symbol_table, is_write);
} else if (auto *merge = dynamic_cast<query::Merge *>(clause)) {
input_op = GenMerge(*merge, input_op, symbol_table, bound_symbols);
// Treat MERGE clause as write, because we do not know if it will create
// anything.
is_write = true;
} else if (auto *with = dynamic_cast<query::With *>(clause)) {
input_op =
GenWith(*with, input_op, symbol_table, is_write, bound_symbols);
@ -489,5 +522,4 @@ std::unique_ptr<LogicalOperator> MakeLogicalPlan(
return std::unique_ptr<LogicalOperator>(input_op);
}
} // namespace plan
} // namespace query
} // namespace query::plan

View File

@ -28,8 +28,8 @@ namespace query {
namespace test_common {
// Custom types for ORDER BY, SKIP and LIMIT and expressions, so that they can
// be used to resolve function calls.
// Custom types for ORDER BY, SKIP, LIMIT, ON MATCH and ON CREATE expressions,
// so that they can be used to resolve function calls.
// Tag type carrying ORDER BY data: a list of (Ordering, Expression *) pairs.
// Being a distinct type lets overloaded helpers recognise an ORDER BY argument.
struct OrderBy {
std::vector<std::pair<Ordering, Expression *>> expressions;
};
@ -39,6 +39,12 @@ struct Skip {
// Tag type carrying the LIMIT expression; null when none was given.
struct Limit {
Expression *expression = nullptr;
};
// Tag type carrying the clauses of the ON MATCH part of a MERGE.
// Built by the ON_MATCH macro and consumed by GetMerge.
struct OnMatch {
std::vector<Clause *> set;
};
// Tag type carrying the clauses of the ON CREATE part of a MERGE.
// Built by the ON_CREATE macro and consumed by GetMerge.
struct OnCreate {
std::vector<Clause *> set;
};
// Helper functions for filling the OrderBy with expressions.
auto FillOrderBy(OrderBy &order_by, Expression *expression,
@ -301,6 +307,26 @@ auto GetRemove(AstTreeStorage &storage, const std::string &name,
return storage.Create<RemoveLabels>(storage.Create<Identifier>(name), labels);
}
///
/// Create a Merge clause for the given Pattern, optionally filling in its
/// ON CREATE part. The ON MATCH part is left as default constructed by
/// storage.Create<query::Merge>().
///
auto GetMerge(AstTreeStorage &storage, Pattern *pattern,
              OnCreate on_create = OnCreate{}) {
  auto *merge_clause = storage.Create<query::Merge>();
  merge_clause->pattern_ = pattern;
  merge_clause->on_create_ = on_create.set;
  return merge_clause;
}

///
/// Create a Merge clause for the given Pattern with an ON MATCH part and an
/// optional ON CREATE part.
///
auto GetMerge(AstTreeStorage &storage, Pattern *pattern, OnMatch on_match,
              OnCreate on_create = OnCreate{}) {
  auto *merge_clause = storage.Create<query::Merge>();
  merge_clause->pattern_ = pattern;
  merge_clause->on_match_ = on_match.set;
  merge_clause->on_create_ = on_create.set;
  return merge_clause;
}
} // namespace test_common
} // namespace query
@ -346,6 +372,15 @@ auto GetRemove(AstTreeStorage &storage, const std::string &name,
query::test_common::GetDelete(storage, {__VA_ARGS__}, true)
// NOTE: SET, REMOVE, MERGE and QUERY expand to calls that pass a variable
// named `storage`, so one must be in scope wherever they are used.
// SET(...): forwards its arguments to GetSet.
#define SET(...) query::test_common::GetSet(storage, __VA_ARGS__)
// REMOVE(...): forwards its arguments to GetRemove.
#define REMOVE(...) query::test_common::GetRemove(storage, __VA_ARGS__)
// MERGE(pattern[, ON_MATCH(...)][, ON_CREATE(...)]): forwards to GetMerge.
#define MERGE(...) query::test_common::GetMerge(storage, __VA_ARGS__)
// ON_MATCH(clauses...): builds an OnMatch holding the given Clause pointers.
#define ON_MATCH(...) \
query::test_common::OnMatch { \
std::vector<query::Clause *> { __VA_ARGS__ } \
}
// ON_CREATE(clauses...): builds an OnCreate holding the given Clause pointers.
#define ON_CREATE(...) \
query::test_common::OnCreate { \
std::vector<query::Clause *> { __VA_ARGS__ } \
}
// QUERY(clauses...): forwards its arguments to GetQuery.
#define QUERY(...) query::test_common::GetQuery(storage, __VA_ARGS__)
// Various operators
#define ADD(expr1, expr2) \

View File

@ -29,6 +29,57 @@ class BaseOpChecker {
virtual void CheckOp(LogicalOperator &, const SymbolTable &) = 0;
};
// Visitor which verifies a logical plan against a list of expected operator
// checkers. Every visited operator is handed to the checker at the *back* of
// the list, and that checker is then removed; the list is therefore consumed
// from back to front as operators are visited.
class PlanChecker : public LogicalOperatorVisitor {
public:
using LogicalOperatorVisitor::PreVisit;
using LogicalOperatorVisitor::Visit;
using LogicalOperatorVisitor::PostVisit;
// Copies the list of checker pointers (the checkers themselves are not owned
// here) and keeps a reference to symbol_table, which must outlive this object.
PlanChecker(const std::list<BaseOpChecker *> &checkers,
const SymbolTable &symbol_table)
: checkers_(checkers), symbol_table_(symbol_table) {}
// Each of the following overrides simply checks the visited operator.
void Visit(CreateNode &op) override { CheckOp(op); }
void Visit(CreateExpand &op) override { CheckOp(op); }
void Visit(Delete &op) override { CheckOp(op); }
void Visit(ScanAll &op) override { CheckOp(op); }
void Visit(Expand &op) override { CheckOp(op); }
void Visit(NodeFilter &op) override { CheckOp(op); }
void Visit(EdgeFilter &op) override { CheckOp(op); }
void Visit(Filter &op) override { CheckOp(op); }
void Visit(Produce &op) override { CheckOp(op); }
void Visit(SetProperty &op) override { CheckOp(op); }
void Visit(SetProperties &op) override { CheckOp(op); }
void Visit(SetLabels &op) override { CheckOp(op); }
void Visit(RemoveProperty &op) override { CheckOp(op); }
void Visit(RemoveLabels &op) override { CheckOp(op); }
void Visit(ExpandUniquenessFilter<VertexAccessor> &op) override {
CheckOp(op);
}
void Visit(ExpandUniquenessFilter<EdgeAccessor> &op) override { CheckOp(op); }
void Visit(Accumulate &op) override { CheckOp(op); }
void Visit(Aggregate &op) override { CheckOp(op); }
void Visit(Skip &op) override { CheckOp(op); }
void Visit(Limit &op) override { CheckOp(op); }
void Visit(OrderBy &op) override { CheckOp(op); }
// Merge is checked as a single operator and then only its input is visited
// by this checker; the match and create sub-plans are not visited here.
// NOTE(review): returning false presumably tells the operator not to continue
// the default traversal of its children -- confirm against Merge::Accept.
bool PreVisit(Merge &op) override {
CheckOp(op);
op.input()->Accept(*this);
return false;
}
// Checkers that have not been consumed yet. Public, presumably so that callers
// can verify that nothing is left over after visiting -- confirm with callers.
std::list<BaseOpChecker *> checkers_;
private:
// Runs the last remaining checker on `op` and removes it from the list.
// Asserts first that there is a checker left to run.
void CheckOp(LogicalOperator &op) {
ASSERT_FALSE(checkers_.empty());
checkers_.back()->CheckOp(op, symbol_table_);
checkers_.pop_back();
}
const SymbolTable &symbol_table_;
};
template <class TOp>
class OpChecker : public BaseOpChecker {
public:
@ -103,49 +154,22 @@ class ExpectAggregate : public OpChecker<Aggregate> {
const std::unordered_set<query::Expression *> group_by_;
};
class PlanChecker : public LogicalOperatorVisitor {
class ExpectMerge : public OpChecker<Merge> {
public:
using LogicalOperatorVisitor::Visit;
using LogicalOperatorVisitor::PostVisit;
ExpectMerge(const std::list<BaseOpChecker *> &on_match,
const std::list<BaseOpChecker *> &on_create)
: on_match_(on_match), on_create_(on_create) {}
PlanChecker(const std::list<BaseOpChecker *> &checkers,
const SymbolTable &symbol_table)
: checkers_(checkers), symbol_table_(symbol_table) {}
void Visit(CreateNode &op) override { CheckOp(op); }
void Visit(CreateExpand &op) override { CheckOp(op); }
void Visit(Delete &op) override { CheckOp(op); }
void Visit(ScanAll &op) override { CheckOp(op); }
void Visit(Expand &op) override { CheckOp(op); }
void Visit(NodeFilter &op) override { CheckOp(op); }
void Visit(EdgeFilter &op) override { CheckOp(op); }
void Visit(Filter &op) override { CheckOp(op); }
void Visit(Produce &op) override { CheckOp(op); }
void Visit(SetProperty &op) override { CheckOp(op); }
void Visit(SetProperties &op) override { CheckOp(op); }
void Visit(SetLabels &op) override { CheckOp(op); }
void Visit(RemoveProperty &op) override { CheckOp(op); }
void Visit(RemoveLabels &op) override { CheckOp(op); }
void Visit(ExpandUniquenessFilter<VertexAccessor> &op) override {
CheckOp(op);
void ExpectOp(Merge &merge, const SymbolTable &symbol_table) override {
PlanChecker check_match(on_match_, symbol_table);
merge.merge_match()->Accept(check_match);
PlanChecker check_create(on_create_, symbol_table);
merge.merge_create()->Accept(check_create);
}
void Visit(ExpandUniquenessFilter<EdgeAccessor> &op) override { CheckOp(op); }
void Visit(Accumulate &op) override { CheckOp(op); }
void Visit(Aggregate &op) override { CheckOp(op); }
void Visit(Skip &op) override { CheckOp(op); }
void Visit(Limit &op) override { CheckOp(op); }
void Visit(OrderBy &op) override { CheckOp(op); }
std::list<BaseOpChecker *> checkers_;
private:
void CheckOp(LogicalOperator &op) {
ASSERT_FALSE(checkers_.empty());
checkers_.back()->CheckOp(op, symbol_table_);
checkers_.pop_back();
}
const SymbolTable &symbol_table_;
const std::list<BaseOpChecker *> &on_match_;
const std::list<BaseOpChecker *> &on_create_;
};
auto MakeSymbolTable(query::Query &query) {
@ -571,4 +595,36 @@ TEST(TestLogicalPlanner, ReturnAddSumCountOrderBy) {
CheckPlan(*query, aggr, ExpectProduce(), ExpectOrderBy());
}
TEST(TestLogicalPlanner, MatchMerge) {
  // Query under test:
  //   MATCH (n) MERGE (n) -[r :r]- (m)
  //   ON MATCH SET n.prop = 42 ON CREATE SET m = n
  //   RETURN n AS n
  Dbms dbms;
  auto dba = dbms.active();
  auto edge_type_r = dba->edge_type("r");
  auto property = dba->property("prop");
  AstTreeStorage storage;
  auto returned_n = IDENT("n");
  auto ast_query =
      QUERY(MATCH(PATTERN(NODE("n"))),
            MERGE(PATTERN(NODE("n"), EDGE("r", edge_type_r), NODE("m")),
                  ON_MATCH(SET(PROPERTY_LOOKUP("n", property), LITERAL(42))),
                  ON_CREATE(SET("m", IDENT("n")))),
            RETURN(returned_n, AS("n")));
  // Checkers expected for the two sub-plans of the Merge operator. The lists
  // hold raw pointers which are released at the end of the test.
  std::list<BaseOpChecker *> match_checkers{
      new ExpectExpand(), new ExpectEdgeFilter(), new ExpectSetProperty()};
  std::list<BaseOpChecker *> create_checkers{new ExpectCreateExpand(),
                                             new ExpectSetProperties()};
  auto symbols = MakeSymbolTable(*ast_query);
  // MERGE is treated as a write, hence an Accumulate is expected after it.
  auto accumulate = ExpectAccumulate({symbols.at(*returned_n)});
  auto logical_plan = MakeLogicalPlan(*ast_query, symbols);
  CheckPlan(*logical_plan, symbols, ExpectScanAll(),
            ExpectMerge(match_checkers, create_checkers), accumulate,
            ExpectProduce());
  for (auto *checker : match_checkers) delete checker;
  match_checkers.clear();
  for (auto *checker : create_checkers) delete checker;
  create_checkers.clear();
}
} // namespace

View File

@ -649,4 +649,75 @@ TEST(TestSymbolGenerator, OrderBy) {
}
}
TEST(TestSymbolGenerator, Merge) {
  // Case: MATCH (n) MERGE (n)
  // Merging only the already declared node `n` must be rejected as a
  // redeclaration.
  {
    AstTreeStorage storage;
    auto ast_query =
        QUERY(MATCH(PATTERN(NODE("n"))), MERGE(PATTERN(NODE("n"))));
    SymbolTable table;
    SymbolGenerator generator(table);
    EXPECT_THROW(ast_query->Accept(generator), RedeclareVariableError);
  }
  // Case: MATCH (n) -[r]- (m) MERGE (a) -[r :rel]- (b)
  // The edge `r` is declared by MATCH, so using it again in MERGE must be
  // rejected as a redeclaration.
  {
    Dbms dbms;
    auto dba = dbms.active();
    auto rel_type = dba->edge_type("rel");
    AstTreeStorage storage;
    auto ast_query =
        QUERY(MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m"))),
              MERGE(PATTERN(NODE("a"), EDGE("r", rel_type), NODE("b"))));
    SymbolTable table;
    SymbolGenerator generator(table);
    EXPECT_THROW(ast_query->Accept(generator), RedeclareVariableError);
  }
  // Case: MERGE (a) -[r]- (b)
  {
    AstTreeStorage storage;
    auto ast_query = QUERY(MERGE(PATTERN(NODE("a"), EDGE("r"), NODE("b"))));
    SymbolTable table;
    SymbolGenerator generator(table);
    // A merged edge must specify a type. This one does not, so a
    // SemanticException is expected.
    EXPECT_THROW(ast_query->Accept(generator), SemanticException);
  }
  // Case: MATCH (n) MERGE (n) -[r :rel]- (m) ON MATCH SET n.prop = 42
  //       ON CREATE SET m.prop = 42 RETURN r AS r
  {
    Dbms dbms;
    auto dba = dbms.active();
    auto rel_type = dba->edge_type("rel");
    auto property = dba->property("prop");
    AstTreeStorage storage;
    auto matched_n = NODE("n");
    auto merged_n = NODE("n");
    auto merged_r = EDGE("r", rel_type);
    auto merged_m = NODE("m");
    auto n_lookup = PROPERTY_LOOKUP("n", property);
    auto m_lookup = PROPERTY_LOOKUP("m", property);
    auto returned_r = IDENT("r");
    auto alias_r = AS("r");
    auto ast_query = QUERY(MATCH(PATTERN(matched_n)),
                           MERGE(PATTERN(merged_n, merged_r, merged_m),
                                 ON_MATCH(SET(n_lookup, LITERAL(42))),
                                 ON_CREATE(SET(m_lookup, LITERAL(42)))),
                           RETURN(returned_r, alias_r));
    SymbolTable table;
    SymbolGenerator generator(table);
    ast_query->Accept(generator);
    // Four symbols are expected: `n`, `r`, `m` and the alias `AS r`.
    EXPECT_EQ(table.max_position(), 4);
    // `n` is shared by MATCH, MERGE and the ON MATCH property lookup.
    auto sym_n = table.at(*matched_n->identifier_);
    EXPECT_EQ(sym_n, table.at(*merged_n->identifier_));
    EXPECT_EQ(sym_n, table.at(*n_lookup->expression_));
    // `r` is a distinct symbol, referenced by RETURN but different from the
    // alias which RETURN introduces.
    auto sym_r = table.at(*merged_r->identifier_);
    EXPECT_NE(sym_r, sym_n);
    EXPECT_EQ(sym_r, table.at(*returned_r));
    EXPECT_NE(sym_r, table.at(*alias_r));
    // `m` is distinct from everything else and is the symbol used by the
    // ON CREATE property lookup.
    auto sym_m = table.at(*merged_m->identifier_);
    EXPECT_NE(sym_m, sym_n);
    EXPECT_NE(sym_m, sym_r);
    EXPECT_NE(sym_m, table.at(*alias_r));
    EXPECT_EQ(sym_m, table.at(*m_lookup->expression_));
  }
}
}