Plan optional matching
Summary: Support OPTIONAL MATCH in test macros. Test planning Optional. Reviewers: florijan, mislav.bradac Reviewed By: florijan Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D322
This commit is contained in:
parent
3b9d13f1e1
commit
9653c703dc
@ -649,6 +649,7 @@ class Match : public Clause {
|
||||
|
||||
protected:
|
||||
Match(int uid) : Clause(uid) {}
|
||||
Match(int uid, bool optional) : Clause(uid), optional_(optional) {}
|
||||
};
|
||||
|
||||
/** @brief Defines the order for sorting values (ascending or descending). */
|
||||
|
@ -363,7 +363,9 @@ bool Expand::ExpandCursor::HandleExistingNode(const VertexAccessor new_node,
|
||||
|
||||
NodeFilter::NodeFilter(const std::shared_ptr<LogicalOperator> &input,
|
||||
Symbol input_symbol, const NodeAtom *node_atom)
|
||||
: input_(input), input_symbol_(input_symbol), node_atom_(node_atom) {}
|
||||
: input_(input ? input : std::make_shared<Once>()),
|
||||
input_symbol_(input_symbol),
|
||||
node_atom_(node_atom) {}
|
||||
|
||||
ACCEPT_WITH_INPUT(NodeFilter)
|
||||
|
||||
@ -410,7 +412,9 @@ bool NodeFilter::NodeFilterCursor::VertexPasses(
|
||||
|
||||
EdgeFilter::EdgeFilter(const std::shared_ptr<LogicalOperator> &input,
|
||||
Symbol input_symbol, const EdgeAtom *edge_atom)
|
||||
: input_(input), input_symbol_(input_symbol), edge_atom_(edge_atom) {}
|
||||
: input_(input ? input : std::make_shared<Once>()),
|
||||
input_symbol_(input_symbol),
|
||||
edge_atom_(edge_atom) {}
|
||||
|
||||
ACCEPT_WITH_INPUT(EdgeFilter)
|
||||
|
||||
@ -458,9 +462,10 @@ bool EdgeFilter::EdgeFilterCursor::EdgePasses(const EdgeAccessor &edge,
|
||||
return true;
|
||||
}
|
||||
|
||||
Filter::Filter(const std::shared_ptr<LogicalOperator> &input_,
|
||||
Expression *expression_)
|
||||
: input_(input_), expression_(expression_) {}
|
||||
Filter::Filter(const std::shared_ptr<LogicalOperator> &input,
|
||||
Expression *expression)
|
||||
: input_(input ? input : std::make_shared<Once>()),
|
||||
expression_(expression) {}
|
||||
|
||||
ACCEPT_WITH_INPUT(Filter)
|
||||
|
||||
|
@ -434,7 +434,7 @@ class NodeFilter : public LogicalOperator {
|
||||
public:
|
||||
/** @brief Construct @c NodeFilter.
|
||||
*
|
||||
* @param input Required, preceding @c LogicalOperator.
|
||||
* @param input Optional, preceding @c LogicalOperator.
|
||||
* @param input_symbol @c Symbol where the node to be filtered is stored.
|
||||
* @param node_atom @c NodeAtom with labels and properties to filter by.
|
||||
*/
|
||||
@ -475,7 +475,7 @@ class EdgeFilter : public LogicalOperator {
|
||||
public:
|
||||
/** @brief Construct @c EdgeFilter.
|
||||
*
|
||||
* @param input Required, preceding @c LogicalOperator.
|
||||
* @param input Optional, preceding @c LogicalOperator.
|
||||
* @param input_symbol @c Symbol where the edge to be filtered is stored.
|
||||
* @param edge_atom @c EdgeAtom with edge types and properties to filter by.
|
||||
*/
|
||||
@ -1231,6 +1231,10 @@ class Optional : public LogicalOperator {
|
||||
void Accept(LogicalOperatorVisitor &visitor) override;
|
||||
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
|
||||
|
||||
auto input() const { return input_; }
|
||||
auto optional() const { return optional_; }
|
||||
const auto &optional_symbols() const { return optional_symbols_; }
|
||||
|
||||
private:
|
||||
const std::shared_ptr<LogicalOperator> input_;
|
||||
const std::shared_ptr<LogicalOperator> optional_;
|
||||
|
@ -101,18 +101,36 @@ auto GenCreate(Create &create, LogicalOperator *input_op,
|
||||
return last_op;
|
||||
}
|
||||
|
||||
// Contextual information used for generating match operators.
|
||||
struct MatchContext {
|
||||
const SymbolTable &symbol_table;
|
||||
// Already bound symbols, which are used to determine whether the operator
|
||||
// should reference them or establish new. This is both read from and written
|
||||
// to during generation.
|
||||
std::unordered_set<int> &bound_symbols;
|
||||
// Determines whether the match should see the new graph state or not.
|
||||
GraphView graph_view = GraphView::OLD;
|
||||
// Symbols for edges established in match, used to ensure Cyphermorphism.
|
||||
std::unordered_set<Symbol, Symbol::Hash> edge_symbols;
|
||||
// All the newly established symbols in match.
|
||||
std::vector<Symbol> new_symbols;
|
||||
};
|
||||
|
||||
// Generates operators for matching the given pattern and appends them to
|
||||
// input_op. Fills the context with all the new symbols and edge symbols.
|
||||
auto GenMatchForPattern(Pattern &pattern, LogicalOperator *input_op,
|
||||
const SymbolTable &symbol_table,
|
||||
std::unordered_set<int> &bound_symbols,
|
||||
std::vector<Symbol> &edge_symbols,
|
||||
GraphView graph_view = GraphView::OLD) {
|
||||
MatchContext &context) {
|
||||
auto &bound_symbols = context.bound_symbols;
|
||||
const auto &symbol_table = context.symbol_table;
|
||||
auto base = [&](NodeAtom *node) {
|
||||
LogicalOperator *last_op = input_op;
|
||||
// If the first atom binds a symbol, we generate a ScanAll which writes it.
|
||||
// Otherwise, someone else generates it (e.g. a previous ScanAll).
|
||||
if (BindSymbol(bound_symbols, symbol_table.at(*node->identifier_))) {
|
||||
const auto &node_symbol = symbol_table.at(*node->identifier_);
|
||||
if (BindSymbol(bound_symbols, node_symbol)) {
|
||||
last_op = new ScanAll(node, std::shared_ptr<LogicalOperator>(last_op),
|
||||
graph_view);
|
||||
context.graph_view);
|
||||
context.new_symbols.emplace_back(node_symbol);
|
||||
}
|
||||
// Even though we may skip generating ScanAll, we still want to add a filter
|
||||
// in case this atom adds more labels/properties for filtering.
|
||||
@ -129,28 +147,36 @@ auto GenMatchForPattern(Pattern &pattern, LogicalOperator *input_op,
|
||||
// If the expand symbols were already bound, then we need to indicate
|
||||
// that they exist. The Expand will then check whether the pattern holds
|
||||
// instead of writing the expansion to symbols.
|
||||
const auto &node_symbol = symbol_table.at(*node->identifier_);
|
||||
auto existing_node = false;
|
||||
auto existing_edge = false;
|
||||
if (!BindSymbol(bound_symbols, symbol_table.at(*node->identifier_))) {
|
||||
if (!BindSymbol(bound_symbols, node_symbol)) {
|
||||
existing_node = true;
|
||||
} else {
|
||||
context.new_symbols.emplace_back(node_symbol);
|
||||
}
|
||||
const auto &edge_symbol = symbol_table.at(*edge->identifier_);
|
||||
auto existing_edge = false;
|
||||
if (!BindSymbol(bound_symbols, edge_symbol)) {
|
||||
existing_edge = true;
|
||||
} else {
|
||||
context.new_symbols.emplace_back(edge_symbol);
|
||||
}
|
||||
last_op =
|
||||
new Expand(node, edge, std::shared_ptr<LogicalOperator>(last_op),
|
||||
input_symbol, existing_node, existing_edge, graph_view);
|
||||
last_op = new Expand(node, edge, std::shared_ptr<LogicalOperator>(last_op),
|
||||
input_symbol, existing_node, existing_edge,
|
||||
context.graph_view);
|
||||
if (!existing_edge) {
|
||||
// Ensure Cyphermorphism (different edge symbols always map to different
|
||||
// edges).
|
||||
if (!edge_symbols.empty()) {
|
||||
if (!context.edge_symbols.empty()) {
|
||||
last_op = new ExpandUniquenessFilter<EdgeAccessor>(
|
||||
std::shared_ptr<LogicalOperator>(last_op), edge_symbol,
|
||||
edge_symbols);
|
||||
std::vector<Symbol>(context.edge_symbols.begin(),
|
||||
context.edge_symbols.end()));
|
||||
}
|
||||
edge_symbols.emplace_back(edge_symbol);
|
||||
}
|
||||
// Insert edge_symbol after creating ExpandUniquenessFilter, so that we
|
||||
// avoid filtering by the same edge we just expanded.
|
||||
context.edge_symbols.insert(edge_symbol);
|
||||
if (!edge->edge_types_.empty() || !edge->properties_.empty()) {
|
||||
last_op = new EdgeFilter(std::shared_ptr<LogicalOperator>(last_op),
|
||||
symbol_table.at(*edge->identifier_), edge);
|
||||
@ -167,16 +193,22 @@ auto GenMatchForPattern(Pattern &pattern, LogicalOperator *input_op,
|
||||
auto GenMatch(Match &match, LogicalOperator *input_op,
|
||||
const SymbolTable &symbol_table,
|
||||
std::unordered_set<int> &bound_symbols) {
|
||||
auto last_op = input_op;
|
||||
std::vector<Symbol> edge_symbols;
|
||||
auto last_op = match.optional_ ? nullptr : input_op;
|
||||
MatchContext context{symbol_table, bound_symbols};
|
||||
for (auto pattern : match.patterns_) {
|
||||
last_op = GenMatchForPattern(*pattern, last_op, symbol_table, bound_symbols,
|
||||
edge_symbols);
|
||||
last_op = GenMatchForPattern(*pattern, last_op, context);
|
||||
}
|
||||
if (match.where_) {
|
||||
last_op = new Filter(std::shared_ptr<LogicalOperator>(last_op),
|
||||
match.where_->expression_);
|
||||
}
|
||||
// Plan Optional after Filter. because with `OPTIONAL MATCH ... WHERE`,
|
||||
// filtering is done while looking for the pattern.
|
||||
if (match.optional_) {
|
||||
last_op = new Optional(std::shared_ptr<LogicalOperator>(input_op),
|
||||
std::shared_ptr<LogicalOperator>(last_op),
|
||||
context.new_symbols);
|
||||
}
|
||||
return last_op;
|
||||
}
|
||||
|
||||
@ -449,10 +481,8 @@ auto GenMerge(query::Merge &merge, LogicalOperator *input_op,
|
||||
// Copy the bound symbol set, because we don't want to use the updated version
|
||||
// when generating the create part.
|
||||
std::unordered_set<int> bound_symbols_copy(bound_symbols);
|
||||
std::vector<Symbol> edge_symbols;
|
||||
auto on_match =
|
||||
GenMatchForPattern(*merge.pattern_, nullptr, symbol_table,
|
||||
bound_symbols_copy, edge_symbols, GraphView::NEW);
|
||||
MatchContext context{symbol_table, bound_symbols_copy, GraphView::NEW};
|
||||
auto on_match = GenMatchForPattern(*merge.pattern_, nullptr, context);
|
||||
// Use the original bound_symbols, so we fill it with new symbols.
|
||||
auto on_create = GenCreateForPattern(*merge.pattern_, nullptr, symbol_table,
|
||||
bound_symbols);
|
||||
|
@ -130,14 +130,13 @@ auto GetPattern(AstTreeStorage &storage, std::vector<PatternAtom *> atoms) {
|
||||
}
|
||||
|
||||
///
|
||||
/// This function creates an AST node which can store patterns and fills them
|
||||
/// with given patterns.
|
||||
/// This function fills an AST node which with given patterns.
|
||||
///
|
||||
/// The function is most commonly used to create Match and Create clauses.
|
||||
///
|
||||
template <class TWithPatterns>
|
||||
auto GetWithPatterns(AstTreeStorage &storage, std::vector<Pattern *> patterns) {
|
||||
auto with_patterns = storage.Create<TWithPatterns>();
|
||||
auto GetWithPatterns(TWithPatterns *with_patterns,
|
||||
std::vector<Pattern *> patterns) {
|
||||
with_patterns->patterns_.insert(with_patterns->patterns_.begin(),
|
||||
patterns.begin(), patterns.end());
|
||||
return with_patterns;
|
||||
@ -346,11 +345,16 @@ auto GetMerge(AstTreeStorage &storage, Pattern *pattern, OnMatch on_match,
|
||||
#define NODE(...) query::test_common::GetNode(storage, __VA_ARGS__)
|
||||
#define EDGE(...) query::test_common::GetEdge(storage, __VA_ARGS__)
|
||||
#define PATTERN(...) query::test_common::GetPattern(storage, {__VA_ARGS__})
|
||||
#define MATCH(...) \
|
||||
query::test_common::GetWithPatterns<query::Match>(storage, {__VA_ARGS__})
|
||||
#define OPTIONAL_MATCH(...) \
|
||||
query::test_common::GetWithPatterns(storage.Create<query::Match>(true), \
|
||||
{__VA_ARGS__})
|
||||
#define MATCH(...) \
|
||||
query::test_common::GetWithPatterns(storage.Create<query::Match>(), \
|
||||
{__VA_ARGS__})
|
||||
#define WHERE(expr) storage.Create<query::Where>((expr))
|
||||
#define CREATE(...) \
|
||||
query::test_common::GetWithPatterns<query::Create>(storage, {__VA_ARGS__})
|
||||
#define CREATE(...) \
|
||||
query::test_common::GetWithPatterns(storage.Create<query::Create>(), \
|
||||
{__VA_ARGS__})
|
||||
#define IDENT(name) storage.Create<query::Identifier>((name))
|
||||
#define LITERAL(val) storage.Create<query::Literal>((val))
|
||||
#define PROPERTY_LOOKUP(...) \
|
||||
|
@ -67,6 +67,11 @@ class PlanChecker : public LogicalOperatorVisitor {
|
||||
op.input()->Accept(*this);
|
||||
return false;
|
||||
}
|
||||
bool PreVisit(Optional &op) override {
|
||||
CheckOp(op);
|
||||
op.input()->Accept(*this);
|
||||
return false;
|
||||
}
|
||||
|
||||
std::list<BaseOpChecker *> checkers_;
|
||||
|
||||
@ -172,6 +177,20 @@ class ExpectMerge : public OpChecker<Merge> {
|
||||
const std::list<BaseOpChecker *> &on_create_;
|
||||
};
|
||||
|
||||
class ExpectOptional : public OpChecker<Optional> {
|
||||
public:
|
||||
ExpectOptional(const std::list<BaseOpChecker *> &optional)
|
||||
: optional_(optional) {}
|
||||
|
||||
void ExpectOp(Optional &optional, const SymbolTable &symbol_table) override {
|
||||
PlanChecker check_optional(optional_, symbol_table);
|
||||
optional.optional()->Accept(check_optional);
|
||||
}
|
||||
|
||||
private:
|
||||
const std::list<BaseOpChecker *> &optional_;
|
||||
};
|
||||
|
||||
auto MakeSymbolTable(query::Query &query) {
|
||||
SymbolTable symbol_table;
|
||||
SymbolGenerator symbol_generator(symbol_table);
|
||||
@ -383,7 +402,7 @@ TEST(TestLogicalPlanner, MultiMatchSameStart) {
|
||||
CheckPlan(*query, ExpectScanAll(), ExpectExpand());
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, MatchEdgeCycle) {
|
||||
TEST(TestLogicalPlanner, MatchExistingEdge) {
|
||||
// Test MATCH (n) -[r]- (m) -[r]- (j)
|
||||
AstTreeStorage storage;
|
||||
auto query = QUERY(
|
||||
@ -392,6 +411,17 @@ TEST(TestLogicalPlanner, MatchEdgeCycle) {
|
||||
CheckPlan(*query, ExpectScanAll(), ExpectExpand(), ExpectExpand());
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, MultiMatchExistingEdgeOtherEdge) {
|
||||
// Test MATCH (n) -[r]- (m) MATCH (m) -[r]- (j) -[e]- (l)
|
||||
AstTreeStorage storage;
|
||||
auto query = QUERY(
|
||||
MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m"))),
|
||||
MATCH(PATTERN(NODE("m"), EDGE("r"), NODE("j"), EDGE("e"), NODE("l"))));
|
||||
// We need ExpandUniquenessFilter for edge `e` against `r` in second MATCH.
|
||||
CheckPlan(*query, ExpectScanAll(), ExpectExpand(), ExpectExpand(),
|
||||
ExpectExpand(), ExpectExpandUniquenessFilter<EdgeAccessor>());
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, MatchWithReturn) {
|
||||
// Test MATCH (old) WITH old AS new RETURN new AS new
|
||||
AstTreeStorage storage;
|
||||
@ -624,4 +654,19 @@ TEST(TestLogicalPlanner, MatchMerge) {
|
||||
on_create.clear();
|
||||
}
|
||||
|
||||
TEST(TestLogicalPlanner, MatchOptionalMatchWhereReturn) {
|
||||
// Test MATCH (n) OPTIONAL MATCH (n) -[r]- (m) WHERE m.prop < 42 RETURN r
|
||||
Dbms dbms;
|
||||
auto dba = dbms.active();
|
||||
auto prop = dba->property("prop");
|
||||
AstTreeStorage storage;
|
||||
auto query = QUERY(MATCH(PATTERN(NODE("n"))),
|
||||
OPTIONAL_MATCH(PATTERN(NODE("n"), EDGE("r"), NODE("m"))),
|
||||
WHERE(LESS(PROPERTY_LOOKUP("m", prop), LITERAL(42))),
|
||||
RETURN(IDENT("r"), AS("r")));
|
||||
std::list<BaseOpChecker *> optional{new ExpectScanAll(), new ExpectExpand(),
|
||||
new ExpectFilter()};
|
||||
CheckPlan(*query, ExpectScanAll(), ExpectOptional(optional), ExpectProduce());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
Loading…
Reference in New Issue
Block a user