Add planning Expand logical operator
Summary: Make LogicalOperator visitable. Add unit tests for logical planner. Add planning Expand logical operators. Test planning edge expansion. Add documentation to planner implementation. Reviewers: florijan, buda, mislav.bradac Reviewed By: mislav.bradac Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D147
This commit is contained in:
parent
b33a654137
commit
bff671af43
@ -8,6 +8,8 @@
|
||||
#include "query/frontend/ast/ast.hpp"
|
||||
#include "query/frontend/interpret/interpret.hpp"
|
||||
#include "query/frontend/semantic/symbol_table.hpp"
|
||||
#include "utils/visitor/visitable.hpp"
|
||||
#include "utils/visitor/visitor.hpp"
|
||||
|
||||
namespace query {
|
||||
|
||||
@ -17,7 +19,17 @@ class Cursor {
|
||||
virtual ~Cursor() {}
|
||||
};
|
||||
|
||||
class LogicalOperator {
|
||||
class CreateOp;
|
||||
class ScanAll;
|
||||
class Expand;
|
||||
class NodeFilter;
|
||||
class EdgeFilter;
|
||||
class Produce;
|
||||
|
||||
using LogicalOperatorVisitor =
|
||||
::utils::Visitor<CreateOp, ScanAll, Expand, NodeFilter, EdgeFilter, Produce>;
|
||||
|
||||
class LogicalOperator : public ::utils::Visitable<LogicalOperatorVisitor> {
|
||||
public:
|
||||
auto children() { return children_; };
|
||||
virtual std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor& db) = 0;
|
||||
@ -30,6 +42,7 @@ class LogicalOperator {
|
||||
class CreateOp : public LogicalOperator {
|
||||
public:
|
||||
CreateOp(NodeAtom* node_atom) : node_atom_(node_atom) {}
|
||||
DEFVISITABLE(LogicalOperatorVisitor);
|
||||
|
||||
private:
|
||||
class CreateOpCursor : public Cursor {
|
||||
@ -71,7 +84,8 @@ class CreateOp : public LogicalOperator {
|
||||
|
||||
class ScanAll : public LogicalOperator {
|
||||
public:
|
||||
ScanAll(NodeAtom* node_atom) : node_atom_(node_atom) {}
|
||||
ScanAll(NodeAtom *node_atom) : node_atom_(node_atom) {}
|
||||
DEFVISITABLE(LogicalOperatorVisitor);
|
||||
|
||||
private:
|
||||
class ScanAllCursor : public Cursor {
|
||||
@ -153,6 +167,12 @@ class Expand : public LogicalOperator {
|
||||
node_cycle_(node_cycle),
|
||||
edge_cycle_(edge_cycle) {}
|
||||
|
||||
void Accept(LogicalOperatorVisitor &visitor) override {
|
||||
visitor.Visit(*this);
|
||||
input_->Accept(visitor);
|
||||
visitor.PostVisit(*this);
|
||||
}
|
||||
|
||||
private:
|
||||
class ExpandCursor : public Cursor {
|
||||
public:
|
||||
@ -320,6 +340,12 @@ class NodeFilter : public LogicalOperator {
|
||||
NodeAtom* node_atom)
|
||||
: input_(input), input_symbol_(input_symbol), node_atom_(node_atom) {}
|
||||
|
||||
void Accept(LogicalOperatorVisitor &visitor) override {
|
||||
visitor.Visit(*this);
|
||||
input_->Accept(visitor);
|
||||
visitor.PostVisit(*this);
|
||||
}
|
||||
|
||||
private:
|
||||
class NodeFilterCursor : public Cursor {
|
||||
public:
|
||||
@ -373,6 +399,12 @@ class EdgeFilter : public LogicalOperator {
|
||||
EdgeAtom* edge_atom)
|
||||
: input_(input), input_symbol_(input_symbol), edge_atom_(edge_atom) {}
|
||||
|
||||
void Accept(LogicalOperatorVisitor &visitor) override {
|
||||
visitor.Visit(*this);
|
||||
input_->Accept(visitor);
|
||||
visitor.PostVisit(*this);
|
||||
}
|
||||
|
||||
private:
|
||||
class EdgeFilterCursor : public Cursor {
|
||||
public:
|
||||
@ -428,6 +460,12 @@ class Produce : public LogicalOperator {
|
||||
children_.emplace_back(input);
|
||||
}
|
||||
|
||||
void Accept(LogicalOperatorVisitor &visitor) override {
|
||||
visitor.Visit(*this);
|
||||
input_->Accept(visitor);
|
||||
visitor.PostVisit(*this);
|
||||
}
|
||||
|
||||
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor& db) override {
|
||||
return std::make_unique<ProduceCursor>(*this, db);
|
||||
}
|
||||
@ -459,4 +497,5 @@ class Produce : public LogicalOperator {
|
||||
std::shared_ptr<LogicalOperator> input_;
|
||||
std::vector<NamedExpression*> named_expressions_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@ -1,5 +1,7 @@
|
||||
#include "query/frontend/logical/planner.hpp"
|
||||
|
||||
#include <unordered_set>
|
||||
|
||||
#include "query/frontend/ast/ast.hpp"
|
||||
#include "utils/exceptions/not_yet_implemented.hpp"
|
||||
|
||||
@ -11,47 +13,108 @@ static LogicalOperator *GenCreate(
|
||||
Create& create, std::shared_ptr<LogicalOperator> input_op)
|
||||
{
|
||||
if (input_op) {
|
||||
// TODO: Support clauses before CREATE, e.g. `MATCH (n) CREATE (m)`
|
||||
throw NotYetImplemented();
|
||||
}
|
||||
if (create.patterns_.size() != 1) {
|
||||
// TODO: Support creating multiple patterns, e.g. `CREATE (n), (m)`
|
||||
throw NotYetImplemented();
|
||||
}
|
||||
auto &pattern = create.patterns_[0];
|
||||
if (pattern->atoms_.size() != 1) {
|
||||
// TODO: Support creating edges.
|
||||
throw NotYetImplemented();
|
||||
}
|
||||
auto *node_atom = dynamic_cast<NodeAtom*>(pattern->atoms_[0]);
|
||||
debug_assert(node_atom, "First pattern atom is not a node");
|
||||
return new CreateOp(node_atom);
|
||||
}
|
||||
|
||||
// Returns false if the symbol was already bound, otherwise binds it and
|
||||
// returns true.
|
||||
bool BindSymbol(std::unordered_set<int> &bound_symbols, const Symbol &symbol)
|
||||
{
|
||||
auto insertion = bound_symbols.insert(symbol.position_);
|
||||
return insertion.second;
|
||||
}
|
||||
|
||||
LogicalOperator *GenMatch(
|
||||
Match& match,
|
||||
std::shared_ptr<LogicalOperator> input_op,
|
||||
const SymbolTable &symbol_table)
|
||||
const SymbolTable &symbol_table,
|
||||
std::unordered_set<int> &bound_symbols)
|
||||
{
|
||||
if (input_op) {
|
||||
// TODO: Support clauses before match.
|
||||
throw NotYetImplemented();
|
||||
}
|
||||
if (match.patterns_.size() != 1) {
|
||||
// TODO: Support matching multiple patterns.
|
||||
throw NotYetImplemented();
|
||||
}
|
||||
auto &pattern = match.patterns_[0];
|
||||
if (pattern->atoms_.size() != 1) {
|
||||
throw NotYetImplemented();
|
||||
debug_assert(!pattern->atoms_.empty(), "Missing atoms in pattern");
|
||||
auto atoms_it = pattern->atoms_.begin();
|
||||
auto last_node = dynamic_cast<NodeAtom*>(*atoms_it++);
|
||||
debug_assert(last_node, "First pattern atom is not a node");
|
||||
// First atom always binds a symbol, and we don't care if it already existed,
|
||||
// because we create a ScanAll which writes that symbol. This may need to
|
||||
// change when we support clauses before match.
|
||||
BindSymbol(bound_symbols, symbol_table.at(*last_node->identifier_));
|
||||
LogicalOperator *last_op = new ScanAll(last_node);
|
||||
if (!last_node->labels_.empty() || !last_node->properties_.empty()) {
|
||||
last_op = new NodeFilter(std::shared_ptr<LogicalOperator>(last_op),
|
||||
symbol_table.at(*last_node->identifier_),
|
||||
last_node);
|
||||
}
|
||||
auto *node_atom = dynamic_cast<NodeAtom*>(pattern->atoms_[0]);
|
||||
auto *scan_all = new ScanAll(node_atom);
|
||||
if (!node_atom->labels_.empty() || !node_atom->properties_.empty()) {
|
||||
auto &input_symbol = symbol_table.at(*node_atom->identifier_);
|
||||
return new NodeFilter(std::shared_ptr<LogicalOperator>(scan_all), input_symbol,
|
||||
node_atom);
|
||||
EdgeAtom *last_edge = nullptr;
|
||||
// Remaining atoms need to follow sequentially as (EdgeAtom, NodeAtom)*
|
||||
for ( ; atoms_it != pattern->atoms_.end(); ++atoms_it) {
|
||||
if (last_edge) {
|
||||
// Store the symbol from the first node as the input to Expand.
|
||||
auto input_symbol = symbol_table.at(*last_node->identifier_);
|
||||
last_node = dynamic_cast<NodeAtom*>(*atoms_it);
|
||||
debug_assert(last_node, "Expected a node atom in pattern.");
|
||||
// If the expand symbols were already bound, then we need to indicate
|
||||
// this as a cycle. The Expand will then check whether the pattern holds
|
||||
// instead of writing the expansion to symbols.
|
||||
auto node_cycle = false;
|
||||
auto edge_cycle = false;
|
||||
if (!BindSymbol(bound_symbols, symbol_table.at(*last_node->identifier_))) {
|
||||
node_cycle = true;
|
||||
}
|
||||
if (!BindSymbol(bound_symbols, symbol_table.at(*last_edge->identifier_))) {
|
||||
edge_cycle = true;
|
||||
}
|
||||
last_op = new Expand(last_node, last_edge,
|
||||
std::shared_ptr<LogicalOperator>(last_op),
|
||||
input_symbol, node_cycle, edge_cycle);
|
||||
if (!last_edge->edge_types_.empty()) {
|
||||
last_op = new EdgeFilter(std::shared_ptr<LogicalOperator>(last_op),
|
||||
symbol_table.at(*last_edge->identifier_),
|
||||
last_edge);
|
||||
}
|
||||
if (!last_node->labels_.empty() || !last_node->properties_.empty()) {
|
||||
last_op = new NodeFilter(std::shared_ptr<LogicalOperator>(last_op),
|
||||
symbol_table.at(*last_node->identifier_),
|
||||
last_node);
|
||||
}
|
||||
// Don't forget to clear the edge, because we expect the next
|
||||
// (EdgeAtom, NodeAtom) sequence.
|
||||
last_edge = nullptr;
|
||||
} else {
|
||||
last_edge = dynamic_cast<EdgeAtom*>(*atoms_it);
|
||||
debug_assert(last_edge, "Expected an edge atom in pattern.");
|
||||
}
|
||||
}
|
||||
return scan_all;
|
||||
debug_assert(!last_edge, "Edge atom should not end the pattern.");
|
||||
return last_op;
|
||||
}
|
||||
|
||||
Produce *GenReturn(Return& ret, std::shared_ptr<LogicalOperator> input_op)
|
||||
{
|
||||
if (!input_op) {
|
||||
// TODO: Support standalone RETURN clause (e.g. RETURN 2)
|
||||
throw NotYetImplemented();
|
||||
}
|
||||
return new Produce(input_op, ret.named_expressions_);
|
||||
@ -61,12 +124,19 @@ Produce *GenReturn(Return& ret, std::shared_ptr<LogicalOperator> input_op)
|
||||
std::unique_ptr<LogicalOperator> MakeLogicalPlan(
|
||||
Query& query, const SymbolTable &symbol_table)
|
||||
{
|
||||
// TODO: Extract functions and state into a class with methods. Possibly a
|
||||
// visitor or similar to avoid all those dynamic casts.
|
||||
LogicalOperator *input_op = nullptr;
|
||||
// bound_symbols set is used to differentiate cycles in pattern matching, so
|
||||
// that the operator can be correctly initialized whether to read the symbol
|
||||
// or write it. E.g. `MATCH (n) -[r]- (n)` would bind (and write) the first
|
||||
// `n`, but the latter `n` would only read the already written information.
|
||||
std::unordered_set<int> bound_symbols;
|
||||
for (auto &clause : query.clauses_) {
|
||||
auto *clause_ptr = clause;
|
||||
if (auto *match = dynamic_cast<Match*>(clause_ptr)) {
|
||||
input_op = GenMatch(*match, std::shared_ptr<LogicalOperator>(input_op),
|
||||
symbol_table);
|
||||
symbol_table, bound_symbols);
|
||||
} else if (auto *ret = dynamic_cast<Return*>(clause_ptr)) {
|
||||
input_op = GenReturn(*ret, std::shared_ptr<LogicalOperator>(input_op));
|
||||
} else if (auto *create = dynamic_cast<Create*>(clause_ptr)) {
|
||||
|
165
tests/unit/query_planner.cpp
Normal file
165
tests/unit/query_planner.cpp
Normal file
@ -0,0 +1,165 @@
|
||||
#include <list>
|
||||
#include <typeinfo>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "query/frontend/ast/ast.hpp"
|
||||
#include "query/frontend/logical/operator.hpp"
|
||||
#include "query/frontend/logical/planner.hpp"
|
||||
#include "query/frontend/semantic/symbol_table.hpp"
|
||||
#include "query/frontend/semantic/symbol_generator.hpp"
|
||||
|
||||
using namespace query;
|
||||
|
||||
namespace {
|
||||
|
||||
class PlanChecker : public LogicalOperatorVisitor {
|
||||
public:
|
||||
using LogicalOperatorVisitor::Visit;
|
||||
using LogicalOperatorVisitor::PostVisit;
|
||||
|
||||
PlanChecker(std::list<size_t> types) : types_(types) {}
|
||||
|
||||
void Visit(CreateOp &op) override { AssertType(op); }
|
||||
void Visit(ScanAll &op) override { AssertType(op); }
|
||||
void Visit(Expand &op) override { AssertType(op); }
|
||||
void Visit(NodeFilter &op) override { AssertType(op); }
|
||||
void Visit(EdgeFilter &op) override { AssertType(op); }
|
||||
void Visit(Produce &op) override { AssertType(op); }
|
||||
|
||||
private:
|
||||
void AssertType(const LogicalOperator &op) {
|
||||
ASSERT_FALSE(types_.empty());
|
||||
ASSERT_EQ(types_.back(), typeid(op).hash_code());
|
||||
types_.pop_back();
|
||||
}
|
||||
std::list<size_t> types_;
|
||||
};
|
||||
|
||||
// Returns a `(name1) -[name2]- (name3) ...` pattern.
|
||||
auto GetPattern(AstTreeStorage &storage, std::vector<std::string> names) {
|
||||
bool is_node{true};
|
||||
auto pattern = storage.Create<Pattern>();
|
||||
for (auto &name : names) {
|
||||
PatternAtom *atom;
|
||||
auto identifier = storage.Create<Identifier>(name);
|
||||
if (is_node) {
|
||||
atom = storage.Create<NodeAtom>(identifier);
|
||||
} else {
|
||||
atom = storage.Create<EdgeAtom>(identifier);
|
||||
}
|
||||
pattern->atoms_.emplace_back(atom);
|
||||
is_node = !is_node;
|
||||
}
|
||||
return pattern;
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, MatchNodeReturn) {
|
||||
// Test MATCH (n) RETURN n AS n
|
||||
AstTreeStorage storage;
|
||||
auto match = storage.Create<Match>();
|
||||
match->patterns_.emplace_back(GetPattern(storage, {"n"}));
|
||||
auto query = storage.query();
|
||||
query->clauses_.emplace_back(match);
|
||||
auto named_expr = storage.Create<NamedExpression>();
|
||||
named_expr->name_ = "n";
|
||||
named_expr->expression_ = storage.Create<Identifier>("n");
|
||||
auto ret = storage.Create<Return>();
|
||||
ret->named_expressions_.emplace_back(named_expr);
|
||||
query->clauses_.emplace_back(ret);
|
||||
SymbolTable symbol_table;
|
||||
SymbolGenerator symbol_generator(symbol_table);
|
||||
query->Accept(symbol_generator);
|
||||
auto plan = MakeLogicalPlan(*query, symbol_table);
|
||||
std::list<size_t> expected_types;
|
||||
expected_types.emplace_back(typeid(ScanAll).hash_code());
|
||||
expected_types.emplace_back(typeid(Produce).hash_code());
|
||||
PlanChecker plan_checker(expected_types);
|
||||
plan->Accept(plan_checker);
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, CreateNodeReturn) {
|
||||
// Test CREATE (n) RETURN n AS n
|
||||
AstTreeStorage storage;
|
||||
auto create = storage.Create<Create>();
|
||||
create->patterns_.emplace_back(GetPattern(storage, {"n"}));
|
||||
auto query = storage.query();
|
||||
query->clauses_.emplace_back(create);
|
||||
auto named_expr = storage.Create<NamedExpression>();
|
||||
named_expr->name_ = "n";
|
||||
named_expr->expression_ = storage.Create<Identifier>("n");
|
||||
auto ret = storage.Create<Return>();
|
||||
ret->named_expressions_.emplace_back(named_expr);
|
||||
query->clauses_.emplace_back(ret);
|
||||
SymbolTable symbol_table;
|
||||
SymbolGenerator symbol_generator(symbol_table);
|
||||
query->Accept(symbol_generator);
|
||||
auto plan = MakeLogicalPlan(*query, symbol_table);
|
||||
std::list<size_t> expected_types;
|
||||
expected_types.emplace_back(typeid(CreateOp).hash_code());
|
||||
expected_types.emplace_back(typeid(Produce).hash_code());
|
||||
PlanChecker plan_checker(expected_types);
|
||||
plan->Accept(plan_checker);
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, MatchLabeledNodes) {
|
||||
// Test MATCH (n :label) RETURN n AS n
|
||||
AstTreeStorage storage;
|
||||
auto pattern = storage.Create<Pattern>();
|
||||
auto node_atom = storage.Create<NodeAtom>(storage.Create<Identifier>("n"));
|
||||
std::string label("label");
|
||||
node_atom->labels_.emplace_back(&label);
|
||||
pattern->atoms_.emplace_back(node_atom);
|
||||
auto match = storage.Create<Match>();
|
||||
match->patterns_.emplace_back(pattern);
|
||||
auto query = storage.query();
|
||||
query->clauses_.emplace_back(match);
|
||||
auto named_expr = storage.Create<NamedExpression>();
|
||||
named_expr->name_ = "n";
|
||||
named_expr->expression_ = storage.Create<Identifier>("n");
|
||||
auto ret = storage.Create<Return>();
|
||||
ret->named_expressions_.emplace_back(named_expr);
|
||||
query->clauses_.emplace_back(ret);
|
||||
SymbolTable symbol_table;
|
||||
SymbolGenerator symbol_generator(symbol_table);
|
||||
query->Accept(symbol_generator);
|
||||
auto plan = MakeLogicalPlan(*query, symbol_table);
|
||||
std::list<size_t> expected_types;
|
||||
expected_types.emplace_back(typeid(ScanAll).hash_code());
|
||||
expected_types.emplace_back(typeid(NodeFilter).hash_code());
|
||||
expected_types.emplace_back(typeid(Produce).hash_code());
|
||||
PlanChecker plan_checker(expected_types);
|
||||
plan->Accept(plan_checker);
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, MatchPathReturn) {
|
||||
// Test MATCH (n) -[r :relationship]- (m) RETURN n AS n
|
||||
AstTreeStorage storage;
|
||||
auto match = storage.Create<Match>();
|
||||
auto pattern = GetPattern(storage, {"n", "r", "m"});
|
||||
match->patterns_.emplace_back(pattern);
|
||||
auto edge_atom = dynamic_cast<EdgeAtom*>(pattern->atoms_[1]);
|
||||
std::string relationship("relationship");
|
||||
edge_atom->edge_types_.emplace_back(&relationship);
|
||||
auto query = storage.query();
|
||||
query->clauses_.emplace_back(match);
|
||||
auto named_expr = storage.Create<NamedExpression>();
|
||||
named_expr->name_ = "n";
|
||||
named_expr->expression_ = storage.Create<Identifier>("n");
|
||||
auto ret = storage.Create<Return>();
|
||||
ret->named_expressions_.emplace_back(named_expr);
|
||||
query->clauses_.emplace_back(ret);
|
||||
SymbolTable symbol_table;
|
||||
SymbolGenerator symbol_generator(symbol_table);
|
||||
query->Accept(symbol_generator);
|
||||
auto plan = MakeLogicalPlan(*query, symbol_table);
|
||||
std::list<size_t> expected_types;
|
||||
expected_types.emplace_back(typeid(ScanAll).hash_code());
|
||||
expected_types.emplace_back(typeid(Expand).hash_code());
|
||||
expected_types.emplace_back(typeid(EdgeFilter).hash_code());
|
||||
expected_types.emplace_back(typeid(Produce).hash_code());
|
||||
PlanChecker plan_checker(expected_types);
|
||||
plan->Accept(plan_checker);
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in New Issue
Block a user