Plan optional matching

Summary:
Support OPTIONAL MATCH in test macros.
Test planning Optional.

Reviewers: florijan, mislav.bradac

Reviewed By: florijan

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D322
This commit is contained in:
Teon Banek 2017-04-28 10:37:49 +02:00
parent 3b9d13f1e1
commit 9653c703dc
6 changed files with 127 additions and 38 deletions

View File

@ -649,6 +649,7 @@ class Match : public Clause {
protected:
Match(int uid) : Clause(uid) {}
Match(int uid, bool optional) : Clause(uid), optional_(optional) {}
};
/** @brief Defines the order for sorting values (ascending or descending). */

View File

@ -363,7 +363,9 @@ bool Expand::ExpandCursor::HandleExistingNode(const VertexAccessor new_node,
NodeFilter::NodeFilter(const std::shared_ptr<LogicalOperator> &input,
Symbol input_symbol, const NodeAtom *node_atom)
: input_(input), input_symbol_(input_symbol), node_atom_(node_atom) {}
: input_(input ? input : std::make_shared<Once>()),
input_symbol_(input_symbol),
node_atom_(node_atom) {}
ACCEPT_WITH_INPUT(NodeFilter)
@ -410,7 +412,9 @@ bool NodeFilter::NodeFilterCursor::VertexPasses(
EdgeFilter::EdgeFilter(const std::shared_ptr<LogicalOperator> &input,
Symbol input_symbol, const EdgeAtom *edge_atom)
: input_(input), input_symbol_(input_symbol), edge_atom_(edge_atom) {}
: input_(input ? input : std::make_shared<Once>()),
input_symbol_(input_symbol),
edge_atom_(edge_atom) {}
ACCEPT_WITH_INPUT(EdgeFilter)
@ -458,9 +462,10 @@ bool EdgeFilter::EdgeFilterCursor::EdgePasses(const EdgeAccessor &edge,
return true;
}
Filter::Filter(const std::shared_ptr<LogicalOperator> &input_,
Expression *expression_)
: input_(input_), expression_(expression_) {}
Filter::Filter(const std::shared_ptr<LogicalOperator> &input,
Expression *expression)
: input_(input ? input : std::make_shared<Once>()),
expression_(expression) {}
ACCEPT_WITH_INPUT(Filter)

View File

@ -434,7 +434,7 @@ class NodeFilter : public LogicalOperator {
public:
/** @brief Construct @c NodeFilter.
*
* @param input Required, preceding @c LogicalOperator.
* @param input Optional, preceding @c LogicalOperator.
* @param input_symbol @c Symbol where the node to be filtered is stored.
* @param node_atom @c NodeAtom with labels and properties to filter by.
*/
@ -475,7 +475,7 @@ class EdgeFilter : public LogicalOperator {
public:
/** @brief Construct @c EdgeFilter.
*
* @param input Required, preceding @c LogicalOperator.
* @param input Optional, preceding @c LogicalOperator.
* @param input_symbol @c Symbol where the edge to be filtered is stored.
* @param edge_atom @c EdgeAtom with edge types and properties to filter by.
*/
@ -1231,6 +1231,10 @@ class Optional : public LogicalOperator {
void Accept(LogicalOperatorVisitor &visitor) override;
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
auto input() const { return input_; }
auto optional() const { return optional_; }
const auto &optional_symbols() const { return optional_symbols_; }
private:
const std::shared_ptr<LogicalOperator> input_;
const std::shared_ptr<LogicalOperator> optional_;

View File

@ -101,18 +101,36 @@ auto GenCreate(Create &create, LogicalOperator *input_op,
return last_op;
}
// Contextual information used for generating match operators.
struct MatchContext {
const SymbolTable &symbol_table;
// Already bound symbols, which are used to determine whether the operator
// should reference them or establish new. This is both read from and written
// to during generation.
std::unordered_set<int> &bound_symbols;
// Determines whether the match should see the new graph state or not.
GraphView graph_view = GraphView::OLD;
// Symbols for edges established in match, used to ensure Cyphermorphism.
std::unordered_set<Symbol, Symbol::Hash> edge_symbols;
// All the newly established symbols in match.
std::vector<Symbol> new_symbols;
};
// Generates operators for matching the given pattern and appends them to
// input_op. Fills the context with all the new symbols and edge symbols.
auto GenMatchForPattern(Pattern &pattern, LogicalOperator *input_op,
const SymbolTable &symbol_table,
std::unordered_set<int> &bound_symbols,
std::vector<Symbol> &edge_symbols,
GraphView graph_view = GraphView::OLD) {
MatchContext &context) {
auto &bound_symbols = context.bound_symbols;
const auto &symbol_table = context.symbol_table;
auto base = [&](NodeAtom *node) {
LogicalOperator *last_op = input_op;
// If the first atom binds a symbol, we generate a ScanAll which writes it.
// Otherwise, someone else generates it (e.g. a previous ScanAll).
if (BindSymbol(bound_symbols, symbol_table.at(*node->identifier_))) {
const auto &node_symbol = symbol_table.at(*node->identifier_);
if (BindSymbol(bound_symbols, node_symbol)) {
last_op = new ScanAll(node, std::shared_ptr<LogicalOperator>(last_op),
graph_view);
context.graph_view);
context.new_symbols.emplace_back(node_symbol);
}
// Even though we may skip generating ScanAll, we still want to add a filter
// in case this atom adds more labels/properties for filtering.
@ -129,28 +147,36 @@ auto GenMatchForPattern(Pattern &pattern, LogicalOperator *input_op,
// If the expand symbols were already bound, then we need to indicate
// that they exist. The Expand will then check whether the pattern holds
// instead of writing the expansion to symbols.
const auto &node_symbol = symbol_table.at(*node->identifier_);
auto existing_node = false;
auto existing_edge = false;
if (!BindSymbol(bound_symbols, symbol_table.at(*node->identifier_))) {
if (!BindSymbol(bound_symbols, node_symbol)) {
existing_node = true;
} else {
context.new_symbols.emplace_back(node_symbol);
}
const auto &edge_symbol = symbol_table.at(*edge->identifier_);
auto existing_edge = false;
if (!BindSymbol(bound_symbols, edge_symbol)) {
existing_edge = true;
} else {
context.new_symbols.emplace_back(edge_symbol);
}
last_op =
new Expand(node, edge, std::shared_ptr<LogicalOperator>(last_op),
input_symbol, existing_node, existing_edge, graph_view);
last_op = new Expand(node, edge, std::shared_ptr<LogicalOperator>(last_op),
input_symbol, existing_node, existing_edge,
context.graph_view);
if (!existing_edge) {
// Ensure Cyphermorphism (different edge symbols always map to different
// edges).
if (!edge_symbols.empty()) {
if (!context.edge_symbols.empty()) {
last_op = new ExpandUniquenessFilter<EdgeAccessor>(
std::shared_ptr<LogicalOperator>(last_op), edge_symbol,
edge_symbols);
std::vector<Symbol>(context.edge_symbols.begin(),
context.edge_symbols.end()));
}
edge_symbols.emplace_back(edge_symbol);
}
// Insert edge_symbol after creating ExpandUniquenessFilter, so that we
// avoid filtering by the same edge we just expanded.
context.edge_symbols.insert(edge_symbol);
if (!edge->edge_types_.empty() || !edge->properties_.empty()) {
last_op = new EdgeFilter(std::shared_ptr<LogicalOperator>(last_op),
symbol_table.at(*edge->identifier_), edge);
@ -167,16 +193,22 @@ auto GenMatchForPattern(Pattern &pattern, LogicalOperator *input_op,
auto GenMatch(Match &match, LogicalOperator *input_op,
const SymbolTable &symbol_table,
std::unordered_set<int> &bound_symbols) {
auto last_op = input_op;
std::vector<Symbol> edge_symbols;
auto last_op = match.optional_ ? nullptr : input_op;
MatchContext context{symbol_table, bound_symbols};
for (auto pattern : match.patterns_) {
last_op = GenMatchForPattern(*pattern, last_op, symbol_table, bound_symbols,
edge_symbols);
last_op = GenMatchForPattern(*pattern, last_op, context);
}
if (match.where_) {
last_op = new Filter(std::shared_ptr<LogicalOperator>(last_op),
match.where_->expression_);
}
// Plan Optional after Filter. because with `OPTIONAL MATCH ... WHERE`,
// filtering is done while looking for the pattern.
if (match.optional_) {
last_op = new Optional(std::shared_ptr<LogicalOperator>(input_op),
std::shared_ptr<LogicalOperator>(last_op),
context.new_symbols);
}
return last_op;
}
@ -449,10 +481,8 @@ auto GenMerge(query::Merge &merge, LogicalOperator *input_op,
// Copy the bound symbol set, because we don't want to use the updated version
// when generating the create part.
std::unordered_set<int> bound_symbols_copy(bound_symbols);
std::vector<Symbol> edge_symbols;
auto on_match =
GenMatchForPattern(*merge.pattern_, nullptr, symbol_table,
bound_symbols_copy, edge_symbols, GraphView::NEW);
MatchContext context{symbol_table, bound_symbols_copy, GraphView::NEW};
auto on_match = GenMatchForPattern(*merge.pattern_, nullptr, context);
// Use the original bound_symbols, so we fill it with new symbols.
auto on_create = GenCreateForPattern(*merge.pattern_, nullptr, symbol_table,
bound_symbols);

View File

@ -130,14 +130,13 @@ auto GetPattern(AstTreeStorage &storage, std::vector<PatternAtom *> atoms) {
}
///
/// This function creates an AST node which can store patterns and fills them
/// with given patterns.
/// This function fills an AST node which with given patterns.
///
/// The function is most commonly used to create Match and Create clauses.
///
template <class TWithPatterns>
auto GetWithPatterns(AstTreeStorage &storage, std::vector<Pattern *> patterns) {
auto with_patterns = storage.Create<TWithPatterns>();
auto GetWithPatterns(TWithPatterns *with_patterns,
std::vector<Pattern *> patterns) {
with_patterns->patterns_.insert(with_patterns->patterns_.begin(),
patterns.begin(), patterns.end());
return with_patterns;
@ -346,11 +345,16 @@ auto GetMerge(AstTreeStorage &storage, Pattern *pattern, OnMatch on_match,
#define NODE(...) query::test_common::GetNode(storage, __VA_ARGS__)
#define EDGE(...) query::test_common::GetEdge(storage, __VA_ARGS__)
#define PATTERN(...) query::test_common::GetPattern(storage, {__VA_ARGS__})
#define MATCH(...) \
query::test_common::GetWithPatterns<query::Match>(storage, {__VA_ARGS__})
#define OPTIONAL_MATCH(...) \
query::test_common::GetWithPatterns(storage.Create<query::Match>(true), \
{__VA_ARGS__})
#define MATCH(...) \
query::test_common::GetWithPatterns(storage.Create<query::Match>(), \
{__VA_ARGS__})
#define WHERE(expr) storage.Create<query::Where>((expr))
#define CREATE(...) \
query::test_common::GetWithPatterns<query::Create>(storage, {__VA_ARGS__})
#define CREATE(...) \
query::test_common::GetWithPatterns(storage.Create<query::Create>(), \
{__VA_ARGS__})
#define IDENT(name) storage.Create<query::Identifier>((name))
#define LITERAL(val) storage.Create<query::Literal>((val))
#define PROPERTY_LOOKUP(...) \

View File

@ -67,6 +67,11 @@ class PlanChecker : public LogicalOperatorVisitor {
op.input()->Accept(*this);
return false;
}
bool PreVisit(Optional &op) override {
CheckOp(op);
op.input()->Accept(*this);
return false;
}
std::list<BaseOpChecker *> checkers_;
@ -172,6 +177,20 @@ class ExpectMerge : public OpChecker<Merge> {
const std::list<BaseOpChecker *> &on_create_;
};
class ExpectOptional : public OpChecker<Optional> {
public:
ExpectOptional(const std::list<BaseOpChecker *> &optional)
: optional_(optional) {}
void ExpectOp(Optional &optional, const SymbolTable &symbol_table) override {
PlanChecker check_optional(optional_, symbol_table);
optional.optional()->Accept(check_optional);
}
private:
const std::list<BaseOpChecker *> &optional_;
};
auto MakeSymbolTable(query::Query &query) {
SymbolTable symbol_table;
SymbolGenerator symbol_generator(symbol_table);
@ -383,7 +402,7 @@ TEST(TestLogicalPlanner, MultiMatchSameStart) {
CheckPlan(*query, ExpectScanAll(), ExpectExpand());
}
TEST(TestLogicalPlanner, MatchEdgeCycle) {
TEST(TestLogicalPlanner, MatchExistingEdge) {
// Test MATCH (n) -[r]- (m) -[r]- (j)
AstTreeStorage storage;
auto query = QUERY(
@ -392,6 +411,17 @@ TEST(TestLogicalPlanner, MatchEdgeCycle) {
CheckPlan(*query, ExpectScanAll(), ExpectExpand(), ExpectExpand());
}
TEST(TestLogicalPlanner, MultiMatchExistingEdgeOtherEdge) {
// Test MATCH (n) -[r]- (m) MATCH (m) -[r]- (j) -[e]- (l)
AstTreeStorage storage;
auto query = QUERY(
MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m"))),
MATCH(PATTERN(NODE("m"), EDGE("r"), NODE("j"), EDGE("e"), NODE("l"))));
// We need ExpandUniquenessFilter for edge `e` against `r` in second MATCH.
CheckPlan(*query, ExpectScanAll(), ExpectExpand(), ExpectExpand(),
ExpectExpand(), ExpectExpandUniquenessFilter<EdgeAccessor>());
}
TEST(TestLogicalPlanner, MatchWithReturn) {
// Test MATCH (old) WITH old AS new RETURN new AS new
AstTreeStorage storage;
@ -624,4 +654,19 @@ TEST(TestLogicalPlanner, MatchMerge) {
on_create.clear();
}
TEST(TestLogicalPlanner, MatchOptionalMatchWhereReturn) {
// Test MATCH (n) OPTIONAL MATCH (n) -[r]- (m) WHERE m.prop < 42 RETURN r
Dbms dbms;
auto dba = dbms.active();
auto prop = dba->property("prop");
AstTreeStorage storage;
auto query = QUERY(MATCH(PATTERN(NODE("n"))),
OPTIONAL_MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m"))),
WHERE(LESS(PROPERTY_LOOKUP("m", prop), LITERAL(42))),
RETURN(IDENT("r"), AS("r")));
std::list<BaseOpChecker *> optional{new ExpectScanAll(), new ExpectExpand(),
new ExpectFilter()};
CheckPlan(*query, ExpectScanAll(), ExpectOptional(optional), ExpectProduce());
}
} // namespace