Add planning of the Expand logical operator

Summary:
Make LogicalOperator visitable.
Add unit tests for logical planner.
Add planning of Expand logical operators.
Test planning edge expansion.
Add documentation to planner implementation.

Reviewers: florijan, buda, mislav.bradac

Reviewed By: mislav.bradac

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D147
This commit is contained in:
Teon Banek 2017-03-21 10:30:14 +01:00
parent b33a654137
commit bff671af43
3 changed files with 287 additions and 13 deletions

View File

@ -8,6 +8,8 @@
#include "query/frontend/ast/ast.hpp"
#include "query/frontend/interpret/interpret.hpp"
#include "query/frontend/semantic/symbol_table.hpp"
#include "utils/visitor/visitable.hpp"
#include "utils/visitor/visitor.hpp"
namespace query {
@ -17,7 +19,17 @@ class Cursor {
virtual ~Cursor() {}
};
class LogicalOperator {
class CreateOp;
class ScanAll;
class Expand;
class NodeFilter;
class EdgeFilter;
class Produce;
using LogicalOperatorVisitor =
::utils::Visitor<CreateOp, ScanAll, Expand, NodeFilter, EdgeFilter, Produce>;
class LogicalOperator : public ::utils::Visitable<LogicalOperatorVisitor> {
public:
auto children() { return children_; };
virtual std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor& db) = 0;
@ -30,6 +42,7 @@ class LogicalOperator {
class CreateOp : public LogicalOperator {
public:
CreateOp(NodeAtom* node_atom) : node_atom_(node_atom) {}
DEFVISITABLE(LogicalOperatorVisitor);
private:
class CreateOpCursor : public Cursor {
@ -71,7 +84,8 @@ class CreateOp : public LogicalOperator {
class ScanAll : public LogicalOperator {
public:
ScanAll(NodeAtom* node_atom) : node_atom_(node_atom) {}
ScanAll(NodeAtom *node_atom) : node_atom_(node_atom) {}
DEFVISITABLE(LogicalOperatorVisitor);
private:
class ScanAllCursor : public Cursor {
@ -153,6 +167,12 @@ class Expand : public LogicalOperator {
node_cycle_(node_cycle),
edge_cycle_(edge_cycle) {}
// Runs `visitor` over this operator and the operator stored in `input_`.
// The order is: `Visit(*this)`, then `input_` accepts the same visitor, then
// `PostVisit(*this)`. `input_` is dereferenced without a null check.
void Accept(LogicalOperatorVisitor &visitor) override {
visitor.Visit(*this);
input_->Accept(visitor);
visitor.PostVisit(*this);
}
private:
class ExpandCursor : public Cursor {
public:
@ -320,6 +340,12 @@ class NodeFilter : public LogicalOperator {
NodeAtom* node_atom)
: input_(input), input_symbol_(input_symbol), node_atom_(node_atom) {}
// Runs `visitor` over this filter and the operator stored in `input_`.
// The order is: `Visit(*this)`, then `input_` accepts the same visitor, then
// `PostVisit(*this)`. `input_` is dereferenced without a null check.
void Accept(LogicalOperatorVisitor &visitor) override {
visitor.Visit(*this);
input_->Accept(visitor);
visitor.PostVisit(*this);
}
private:
class NodeFilterCursor : public Cursor {
public:
@ -373,6 +399,12 @@ class EdgeFilter : public LogicalOperator {
EdgeAtom* edge_atom)
: input_(input), input_symbol_(input_symbol), edge_atom_(edge_atom) {}
// Runs `visitor` over this filter and the operator stored in `input_`.
// The order is: `Visit(*this)`, then `input_` accepts the same visitor, then
// `PostVisit(*this)`. `input_` is dereferenced without a null check.
void Accept(LogicalOperatorVisitor &visitor) override {
visitor.Visit(*this);
input_->Accept(visitor);
visitor.PostVisit(*this);
}
private:
class EdgeFilterCursor : public Cursor {
public:
@ -428,6 +460,12 @@ class Produce : public LogicalOperator {
children_.emplace_back(input);
}
// Runs `visitor` over this operator and the operator stored in `input_`.
// The order is: `Visit(*this)`, then `input_` accepts the same visitor, then
// `PostVisit(*this)`. `input_` is dereferenced without a null check.
void Accept(LogicalOperatorVisitor &visitor) override {
visitor.Visit(*this);
input_->Accept(visitor);
visitor.PostVisit(*this);
}
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor& db) override {
return std::make_unique<ProduceCursor>(*this, db);
}
@ -459,4 +497,5 @@ class Produce : public LogicalOperator {
std::shared_ptr<LogicalOperator> input_;
std::vector<NamedExpression*> named_expressions_;
};
}

View File

@ -1,5 +1,7 @@
#include "query/frontend/logical/planner.hpp"
#include <unordered_set>
#include "query/frontend/ast/ast.hpp"
#include "utils/exceptions/not_yet_implemented.hpp"
@ -11,47 +13,108 @@ static LogicalOperator *GenCreate(
Create& create, std::shared_ptr<LogicalOperator> input_op)
{
if (input_op) {
// TODO: Support clauses before CREATE, e.g. `MATCH (n) CREATE (m)`
throw NotYetImplemented();
}
if (create.patterns_.size() != 1) {
// TODO: Support creating multiple patterns, e.g. `CREATE (n), (m)`
throw NotYetImplemented();
}
auto &pattern = create.patterns_[0];
if (pattern->atoms_.size() != 1) {
// TODO: Support creating edges.
throw NotYetImplemented();
}
auto *node_atom = dynamic_cast<NodeAtom*>(pattern->atoms_[0]);
debug_assert(node_atom, "First pattern atom is not a node");
return new CreateOp(node_atom);
}
// Records the position of `symbol` in `bound_symbols`.
// Yields true when this call added the position to the set, and false when
// the position was already present (i.e. the symbol had been bound earlier).
bool BindSymbol(std::unordered_set<int> &bound_symbols, const Symbol &symbol)
{
  // `insert` reports through `.second` whether a new element was added.
  return bound_symbols.insert(symbol.position_).second;
}
LogicalOperator *GenMatch(
Match& match,
std::shared_ptr<LogicalOperator> input_op,
const SymbolTable &symbol_table)
const SymbolTable &symbol_table,
std::unordered_set<int> &bound_symbols)
{
if (input_op) {
// TODO: Support clauses before match.
throw NotYetImplemented();
}
if (match.patterns_.size() != 1) {
// TODO: Support matching multiple patterns.
throw NotYetImplemented();
}
auto &pattern = match.patterns_[0];
if (pattern->atoms_.size() != 1) {
throw NotYetImplemented();
debug_assert(!pattern->atoms_.empty(), "Missing atoms in pattern");
auto atoms_it = pattern->atoms_.begin();
auto last_node = dynamic_cast<NodeAtom*>(*atoms_it++);
debug_assert(last_node, "First pattern atom is not a node");
// First atom always binds a symbol, and we don't care if it already existed,
// because we create a ScanAll which writes that symbol. This may need to
// change when we support clauses before match.
BindSymbol(bound_symbols, symbol_table.at(*last_node->identifier_));
LogicalOperator *last_op = new ScanAll(last_node);
if (!last_node->labels_.empty() || !last_node->properties_.empty()) {
last_op = new NodeFilter(std::shared_ptr<LogicalOperator>(last_op),
symbol_table.at(*last_node->identifier_),
last_node);
}
auto *node_atom = dynamic_cast<NodeAtom*>(pattern->atoms_[0]);
auto *scan_all = new ScanAll(node_atom);
if (!node_atom->labels_.empty() || !node_atom->properties_.empty()) {
auto &input_symbol = symbol_table.at(*node_atom->identifier_);
return new NodeFilter(std::shared_ptr<LogicalOperator>(scan_all), input_symbol,
node_atom);
EdgeAtom *last_edge = nullptr;
// Remaining atoms need to follow sequentially as (EdgeAtom, NodeAtom)*
for ( ; atoms_it != pattern->atoms_.end(); ++atoms_it) {
if (last_edge) {
// Store the symbol from the first node as the input to Expand.
auto input_symbol = symbol_table.at(*last_node->identifier_);
last_node = dynamic_cast<NodeAtom*>(*atoms_it);
debug_assert(last_node, "Expected a node atom in pattern.");
// If the expand symbols were already bound, then we need to indicate
// this as a cycle. The Expand will then check whether the pattern holds
// instead of writing the expansion to symbols.
auto node_cycle = false;
auto edge_cycle = false;
if (!BindSymbol(bound_symbols, symbol_table.at(*last_node->identifier_))) {
node_cycle = true;
}
if (!BindSymbol(bound_symbols, symbol_table.at(*last_edge->identifier_))) {
edge_cycle = true;
}
last_op = new Expand(last_node, last_edge,
std::shared_ptr<LogicalOperator>(last_op),
input_symbol, node_cycle, edge_cycle);
if (!last_edge->edge_types_.empty()) {
last_op = new EdgeFilter(std::shared_ptr<LogicalOperator>(last_op),
symbol_table.at(*last_edge->identifier_),
last_edge);
}
if (!last_node->labels_.empty() || !last_node->properties_.empty()) {
last_op = new NodeFilter(std::shared_ptr<LogicalOperator>(last_op),
symbol_table.at(*last_node->identifier_),
last_node);
}
// Don't forget to clear the edge, because we expect the next
// (EdgeAtom, NodeAtom) sequence.
last_edge = nullptr;
} else {
last_edge = dynamic_cast<EdgeAtom*>(*atoms_it);
debug_assert(last_edge, "Expected an edge atom in pattern.");
}
}
return scan_all;
debug_assert(!last_edge, "Edge atom should not end the pattern.");
return last_op;
}
Produce *GenReturn(Return& ret, std::shared_ptr<LogicalOperator> input_op)
{
if (!input_op) {
// TODO: Support standalone RETURN clause (e.g. RETURN 2)
throw NotYetImplemented();
}
return new Produce(input_op, ret.named_expressions_);
@ -61,12 +124,19 @@ Produce *GenReturn(Return& ret, std::shared_ptr<LogicalOperator> input_op)
std::unique_ptr<LogicalOperator> MakeLogicalPlan(
Query& query, const SymbolTable &symbol_table)
{
// TODO: Extract functions and state into a class with methods. Possibly a
// visitor or similar to avoid all those dynamic casts.
LogicalOperator *input_op = nullptr;
// bound_symbols set is used to differentiate cycles in pattern matching, so
// that the operator can be correctly initialized whether to read the symbol
// or write it. E.g. `MATCH (n) -[r]- (n)` would bind (and write) the first
// `n`, but the latter `n` would only read the already written information.
std::unordered_set<int> bound_symbols;
for (auto &clause : query.clauses_) {
auto *clause_ptr = clause;
if (auto *match = dynamic_cast<Match*>(clause_ptr)) {
input_op = GenMatch(*match, std::shared_ptr<LogicalOperator>(input_op),
symbol_table);
symbol_table, bound_symbols);
} else if (auto *ret = dynamic_cast<Return*>(clause_ptr)) {
input_op = GenReturn(*ret, std::shared_ptr<LogicalOperator>(input_op));
} else if (auto *create = dynamic_cast<Create*>(clause_ptr)) {

View File

@ -0,0 +1,165 @@
#include <list>
#include <typeinfo>
#include "gtest/gtest.h"
#include "query/frontend/ast/ast.hpp"
#include "query/frontend/logical/operator.hpp"
#include "query/frontend/logical/planner.hpp"
#include "query/frontend/semantic/symbol_table.hpp"
#include "query/frontend/semantic/symbol_generator.hpp"
using namespace query;
namespace {
// Visitor which checks that operators are visited in an expected sequence.
//
// Expected operator types are given as `typeid(...).hash_code()` values. They
// are consumed from the back of the list, so the last element is compared
// against the first visited operator. When the checker is destroyed it also
// verifies that every expected type was consumed, so a plan containing fewer
// operators than expected is reported as a failure.
class PlanChecker : public LogicalOperatorVisitor {
 public:
  using LogicalOperatorVisitor::Visit;
  using LogicalOperatorVisitor::PostVisit;

  // `types` holds the hash codes of the expected operator types; the last
  // element is matched first.
  explicit PlanChecker(std::list<size_t> types) : types_(types) {}

  // Fails the running test if some expected operators were never visited.
  ~PlanChecker() { AssertAllVisited(); }

  void Visit(CreateOp &op) override { AssertType(op); }
  void Visit(ScanAll &op) override { AssertType(op); }
  void Visit(Expand &op) override { AssertType(op); }
  void Visit(NodeFilter &op) override { AssertType(op); }
  void Visit(EdgeFilter &op) override { AssertType(op); }
  void Visit(Produce &op) override { AssertType(op); }

 private:
  // Compares the dynamic type of `op` with the next expected type and, on a
  // match, removes that type from the list.
  void AssertType(const LogicalOperator &op) {
    ASSERT_FALSE(types_.empty());
    ASSERT_EQ(types_.back(), typeid(op).hash_code());
    types_.pop_back();
  }

  // Fatal gtest assertions expand to a `return` statement which is not
  // allowed directly inside a destructor, hence this void helper.
  void AssertAllVisited() {
    ASSERT_EQ(types_.size(), 0U) << "Some expected operators were not visited";
  }

  std::list<size_t> types_;
};
// Builds a `(name1) -[name2]- (name3) ...` pattern inside `storage`.
// Names at even indices (0, 2, ...) become node atoms, names at odd indices
// become edge atoms; every atom gets a freshly created identifier.
auto GetPattern(AstTreeStorage &storage, std::vector<std::string> names) {
  auto pattern = storage.Create<Pattern>();
  for (size_t index = 0; index < names.size(); ++index) {
    auto identifier = storage.Create<Identifier>(names[index]);
    PatternAtom *atom;
    const bool at_node_position = (index % 2 == 0);
    if (at_node_position) {
      atom = storage.Create<NodeAtom>(identifier);
    } else {
      atom = storage.Create<EdgeAtom>(identifier);
    }
    pattern->atoms_.emplace_back(atom);
  }
  return pattern;
}
TEST(TestLogicalPlanner, MatchNodeReturn) {
  // Plans `MATCH (n) RETURN n AS n`; the checker should see Produce first and
  // ScanAll after it.
  AstTreeStorage storage;

  // MATCH (n)
  auto match_clause = storage.Create<Match>();
  match_clause->patterns_.emplace_back(GetPattern(storage, {"n"}));
  auto query_root = storage.query();
  query_root->clauses_.emplace_back(match_clause);

  // RETURN n AS n
  auto returned_n = storage.Create<NamedExpression>();
  returned_n->name_ = "n";
  returned_n->expression_ = storage.Create<Identifier>("n");
  auto return_clause = storage.Create<Return>();
  return_clause->named_expressions_.emplace_back(returned_n);
  query_root->clauses_.emplace_back(return_clause);

  // Generate symbols for the query, then build the plan from it.
  SymbolTable symbol_table;
  SymbolGenerator symbol_generator(symbol_table);
  query_root->Accept(symbol_generator);
  auto plan = MakeLogicalPlan(*query_root, symbol_table);

  // The checker consumes this list from the back.
  const std::list<size_t> expected_types{typeid(ScanAll).hash_code(),
                                         typeid(Produce).hash_code()};
  PlanChecker plan_checker(expected_types);
  plan->Accept(plan_checker);
}
TEST(TestLogicalPlanner, CreateNodeReturn) {
  // Plans `CREATE (n) RETURN n AS n`; the checker should see Produce first
  // and CreateOp after it.
  AstTreeStorage storage;

  // CREATE (n)
  auto create_clause = storage.Create<Create>();
  create_clause->patterns_.emplace_back(GetPattern(storage, {"n"}));
  auto query_root = storage.query();
  query_root->clauses_.emplace_back(create_clause);

  // RETURN n AS n
  auto returned_n = storage.Create<NamedExpression>();
  returned_n->name_ = "n";
  returned_n->expression_ = storage.Create<Identifier>("n");
  auto return_clause = storage.Create<Return>();
  return_clause->named_expressions_.emplace_back(returned_n);
  query_root->clauses_.emplace_back(return_clause);

  // Generate symbols for the query, then build the plan from it.
  SymbolTable symbol_table;
  SymbolGenerator symbol_generator(symbol_table);
  query_root->Accept(symbol_generator);
  auto plan = MakeLogicalPlan(*query_root, symbol_table);

  // The checker consumes this list from the back.
  const std::list<size_t> expected_types{typeid(CreateOp).hash_code(),
                                         typeid(Produce).hash_code()};
  PlanChecker plan_checker(expected_types);
  plan->Accept(plan_checker);
}
TEST(TestLogicalPlanner, MatchLabeledNodes) {
  // Plans `MATCH (n :label) RETURN n AS n`; the checker should see Produce,
  // then NodeFilter, then ScanAll.
  AstTreeStorage storage;

  // (n :label) -- built by hand because the node atom needs a label.
  auto labeled_pattern = storage.Create<Pattern>();
  auto labeled_node =
      storage.Create<NodeAtom>(storage.Create<Identifier>("n"));
  // The atom stores the address of this local, which outlives all of its uses
  // within the test body.
  std::string label("label");
  labeled_node->labels_.emplace_back(&label);
  labeled_pattern->atoms_.emplace_back(labeled_node);

  // MATCH <pattern>
  auto match_clause = storage.Create<Match>();
  match_clause->patterns_.emplace_back(labeled_pattern);
  auto query_root = storage.query();
  query_root->clauses_.emplace_back(match_clause);

  // RETURN n AS n
  auto returned_n = storage.Create<NamedExpression>();
  returned_n->name_ = "n";
  returned_n->expression_ = storage.Create<Identifier>("n");
  auto return_clause = storage.Create<Return>();
  return_clause->named_expressions_.emplace_back(returned_n);
  query_root->clauses_.emplace_back(return_clause);

  // Generate symbols for the query, then build the plan from it.
  SymbolTable symbol_table;
  SymbolGenerator symbol_generator(symbol_table);
  query_root->Accept(symbol_generator);
  auto plan = MakeLogicalPlan(*query_root, symbol_table);

  // The checker consumes this list from the back.
  const std::list<size_t> expected_types{typeid(ScanAll).hash_code(),
                                         typeid(NodeFilter).hash_code(),
                                         typeid(Produce).hash_code()};
  PlanChecker plan_checker(expected_types);
  plan->Accept(plan_checker);
}
TEST(TestLogicalPlanner, MatchPathReturn) {
  // Test MATCH (n) -[r :relationship]- (m) RETURN n AS n
  // The checker should see Produce, then EdgeFilter, then Expand and finally
  // ScanAll.
  AstTreeStorage storage;
  auto match = storage.Create<Match>();
  auto pattern = GetPattern(storage, {"n", "r", "m"});
  match->patterns_.emplace_back(pattern);
  auto edge_atom = dynamic_cast<EdgeAtom*>(pattern->atoms_[1]);
  // Stop here with a readable failure instead of dereferencing a null pointer
  // if the second atom is not an edge.
  ASSERT_TRUE(edge_atom != nullptr) << "Second pattern atom is not an edge";
  // The atom stores the address of this local, which outlives all of its uses
  // within the test body.
  std::string relationship("relationship");
  edge_atom->edge_types_.emplace_back(&relationship);
  auto query = storage.query();
  query->clauses_.emplace_back(match);
  // RETURN n AS n
  auto named_expr = storage.Create<NamedExpression>();
  named_expr->name_ = "n";
  named_expr->expression_ = storage.Create<Identifier>("n");
  auto ret = storage.Create<Return>();
  ret->named_expressions_.emplace_back(named_expr);
  query->clauses_.emplace_back(ret);
  // Generate symbols for the query, then build the plan from it.
  SymbolTable symbol_table;
  SymbolGenerator symbol_generator(symbol_table);
  query->Accept(symbol_generator);
  auto plan = MakeLogicalPlan(*query, symbol_table);
  // The checker consumes this list from the back.
  std::list<size_t> expected_types;
  expected_types.emplace_back(typeid(ScanAll).hash_code());
  expected_types.emplace_back(typeid(Expand).hash_code());
  expected_types.emplace_back(typeid(EdgeFilter).hash_code());
  expected_types.emplace_back(typeid(Produce).hash_code());
  PlanChecker plan_checker(expected_types);
  plan->Accept(plan_checker);
}
}