Inline filter inside ExpandVariable

Summary:
Reorder class definition in ast.hpp.
Test inlining filters in ExpandVariable.

Reviewers: florijan, mislav.bradac

Reviewed By: florijan

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D726
This commit is contained in:
Teon Banek 2017-08-30 15:37:00 +02:00
parent e68f7ea536
commit 52709ad04c
7 changed files with 374 additions and 284 deletions

View File

@ -71,6 +71,9 @@ endif()
# TODO: set here 17 once it will be available in the cmake version (3.8)
set(cxx_standard 14)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++1z -Wall -Wno-c++1z-extensions")
# Don't omit frame pointer in RelWithDebInfo, for additional callchain debug.
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO
"${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -fno-omit-frame-pointer")
# -----------------------------------------------------------------------------
# dir variables
@ -99,7 +102,7 @@ endif()
# default build type is debug
if ("${CMAKE_BUILD_TYPE}" STREQUAL "")
set(CMAKE_BUILD_TYPE "debug")
set(CMAKE_BUILD_TYPE "Debug")
endif()
message(STATUS "CMake build type: ${CMAKE_BUILD_TYPE}")
# -----------------------------------------------------------------------------
@ -356,7 +359,7 @@ string(STRIP ${COMMIT_HASH} COMMIT_HASH)
set(MEMGRAPH_BUILD_NAME
"memgraph_${COMMIT_NO}_${COMMIT_HASH}_${COMMIT_BRANCH}_${CMAKE_BUILD_TYPE}")
add_custom_target(memgraph_link_target ALL
COMMAND ${CMAKE_COMMAND} -E create_symlink ${CMAKE_BINARY_DIR}/${MEMGRAPH_BUILD_NAME} ${CMAKE_BINARY_DIR}/memgraph DEPENDS ${MEMGRAPH_BUILD_NAME})
COMMAND ${CMAKE_COMMAND} -E create_symlink ${CMAKE_BINARY_DIR}/${MEMGRAPH_BUILD_NAME} ${CMAKE_BINARY_DIR}/memgraph DEPENDS ${MEMGRAPH_BUILD_NAME})
# -----------------------------------------------------------------------------
# memgraph main executable

View File

@ -77,6 +77,8 @@ class Tree : public ::utils::Visitable<HierarchicalTreeVisitor>,
const int uid_;
};
// Expressions
class Expression : public Tree {
friend class AstTreeStorage;
@ -87,6 +89,29 @@ class Expression : public Tree {
Expression(int uid) : Tree(uid) {}
};
class Where : public Tree {
friend class AstTreeStorage;
public:
DEFVISITABLE(TreeVisitor<TypedValue>);
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
expression_->Accept(visitor);
}
return visitor.PostVisit(*this);
}
Where *Clone(AstTreeStorage &storage) const override {
return storage.Create<Where>(expression_->Clone(storage));
}
Expression *expression_ = nullptr;
protected:
Where(int uid) : Tree(uid) {}
Where(int uid, Expression *expression) : Tree(uid), expression_(expression) {}
};
class BinaryOperator : public Expression {
friend class AstTreeStorage;
@ -840,6 +865,42 @@ class Aggregation : public BinaryOperator {
}
};
class All : public Expression {
friend class AstTreeStorage;
public:
DEFVISITABLE(TreeVisitor<TypedValue>);
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
identifier_->Accept(visitor) && list_expression_->Accept(visitor) &&
where_->Accept(visitor);
}
return visitor.PostVisit(*this);
}
All *Clone(AstTreeStorage &storage) const override {
return storage.Create<All>(identifier_->Clone(storage),
list_expression_->Clone(storage),
where_->Clone(storage));
}
Identifier *identifier_ = nullptr;
Expression *list_expression_ = nullptr;
Where *where_ = nullptr;
protected:
All(int uid, Identifier *identifier, Expression *list_expression,
Where *where)
: Expression(uid),
identifier_(identifier),
list_expression_(list_expression),
where_(where) {
debug_assert(identifier, "identifier must not be nullptr");
debug_assert(list_expression, "list_expression must not be nullptr");
debug_assert(where, "where must not be nullptr");
}
};
class NamedExpression : public Tree {
friend class AstTreeStorage;
@ -877,6 +938,8 @@ class NamedExpression : public Tree {
token_position_(token_position) {}
};
// Pattern atoms
class PatternAtom : public Tree {
friend class AstTreeStorage;
@ -1026,15 +1089,6 @@ class BreadthFirstAtom : public EdgeAtom {
max_depth_(max_depth) {}
};
class Clause : public Tree {
friend class AstTreeStorage;
public:
Clause(int uid) : Tree(uid) {}
Clause *Clone(AstTreeStorage &storage) const override = 0;
};
class Pattern : public Tree {
friend class AstTreeStorage;
@ -1065,6 +1119,17 @@ class Pattern : public Tree {
Pattern(int uid) : Tree(uid) {}
};
// Clauses
class Clause : public Tree {
friend class AstTreeStorage;
public:
Clause(int uid) : Tree(uid) {}
Clause *Clone(AstTreeStorage &storage) const override = 0;
};
class Query : public Tree {
friend class AstTreeStorage;
@ -1120,65 +1185,6 @@ class Create : public Clause {
std::vector<Pattern *> patterns_;
};
class Where : public Tree {
friend class AstTreeStorage;
public:
DEFVISITABLE(TreeVisitor<TypedValue>);
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
expression_->Accept(visitor);
}
return visitor.PostVisit(*this);
}
Where *Clone(AstTreeStorage &storage) const override {
return storage.Create<Where>(expression_->Clone(storage));
}
Expression *expression_ = nullptr;
protected:
Where(int uid) : Tree(uid) {}
Where(int uid, Expression *expression) : Tree(uid), expression_(expression) {}
};
class All : public Expression {
friend class AstTreeStorage;
public:
DEFVISITABLE(TreeVisitor<TypedValue>);
bool Accept(HierarchicalTreeVisitor &visitor) override {
if (visitor.PreVisit(*this)) {
identifier_->Accept(visitor) && list_expression_->Accept(visitor) &&
where_->Accept(visitor);
}
return visitor.PostVisit(*this);
}
All *Clone(AstTreeStorage &storage) const override {
return storage.Create<All>(identifier_->Clone(storage),
list_expression_->Clone(storage),
where_->Clone(storage));
}
Identifier *identifier_ = nullptr;
Expression *list_expression_ = nullptr;
Where *where_ = nullptr;
protected:
All(int uid, Identifier *identifier, Expression *list_expression,
Where *where)
: Expression(uid),
identifier_(identifier),
list_expression_(list_expression),
where_(where) {
debug_assert(identifier, "identifier must not be nullptr");
debug_assert(list_expression, "list_expression must not be nullptr");
debug_assert(where, "where must not be nullptr");
}
};
class Match : public Clause {
friend class AstTreeStorage;

View File

@ -44,6 +44,18 @@ void ExpectType(Symbol symbol, TypedValue value, TypedValue::Type expected) {
symbol.name(), value.type());
}
// Returns boolean result of evaluating filter expression. Null is treated as
// false. Other non boolean values raise a QueryRuntimeException.
bool EvaluateFilter(ExpressionEvaluator &evaluator, Expression *filter) {
TypedValue result = filter->Accept(evaluator);
// Null is treated like false.
if (result.IsNull()) return false;
if (result.type() != TypedValue::Type::Bool)
throw QueryRuntimeException(
"Filter expression must be a bool or null, but got {}.", result.type());
return result.Value<bool>();
}
} // namespace
bool Once::OnceCursor::Pull(Frame &, const SymbolTable &) {
@ -239,7 +251,7 @@ ScanAll::ScanAll(const std::shared_ptr<LogicalOperator> &input,
output_symbol_(output_symbol),
graph_view_(graph_view) {
permanent_assert(graph_view != GraphView::AS_IS,
"ScanAll must have explicitly defined GraphView")
"ScanAll must have explicitly defined GraphView");
}
ACCEPT_WITH_INPUT(ScanAll)
@ -300,10 +312,10 @@ std::unique_ptr<Cursor> ScanAllByLabelPropertyRange::MakeCursor(
ExpressionEvaluator evaluator(frame, symbol_table, db, graph_view_);
auto convert = [&evaluator](const auto &bound)
-> std::experimental::optional<utils::Bound<PropertyValue>> {
if (!bound) return std::experimental::nullopt;
return std::experimental::make_optional(utils::Bound<PropertyValue>(
bound.value().value()->Accept(evaluator), bound.value().type()));
};
if (!bound) return std::experimental::nullopt;
return std::experimental::make_optional(utils::Bound<PropertyValue>(
bound.value().value()->Accept(evaluator), bound.value().type()));
};
return db.Vertices(label_, property_, convert(lower_bound()),
convert(upper_bound()), graph_view_ == GraphView::NEW);
};
@ -531,12 +543,14 @@ ExpandVariable::ExpandVariable(Symbol node_symbol, Symbol edge_symbol,
Expression *lower_bound, Expression *upper_bound,
const std::shared_ptr<LogicalOperator> &input,
Symbol input_symbol, bool existing_node,
bool existing_edge, GraphView graph_view)
bool existing_edge, GraphView graph_view,
Expression *filter)
: ExpandCommon(node_symbol, edge_symbol, direction, input, input_symbol,
existing_node, existing_edge, graph_view),
lower_bound_(lower_bound),
upper_bound_(upper_bound),
is_reverse_(is_reverse) {}
is_reverse_(is_reverse),
filter_(filter) {}
bool Expand::ExpandCursor::HandleExistingEdge(const EdgeAccessor &new_edge,
Frame &frame) const {
@ -612,8 +626,9 @@ class ExpandVariableCursor : public Cursor {
: self_(self), db_(db), input_cursor_(self.input_->MakeCursor(db)) {}
bool Pull(Frame &frame, const SymbolTable &symbol_table) override {
ExpressionEvaluator evaluator(frame, symbol_table, db_, self_.graph_view_);
while (true) {
if (Expand(frame)) return true;
if (Expand(frame, symbol_table)) return true;
if (PullInput(frame, symbol_table)) {
// if lower bound is zero we also yield empty paths
@ -625,8 +640,11 @@ class ExpandVariableCursor : public Cursor {
// take into account existing_edge when yielding empty paths
if ((!self_.existing_edge_ || edges_on_frame.empty()) &&
// Place the start vertex on the frame.
self_.HandleExistingNode(start_vertex, frame))
self_.HandleExistingNode(start_vertex, frame)) {
if (self_.filter_ && !EvaluateFilter(evaluator, self_.filter_))
continue;
return true;
}
}
// if lower bound is not zero, we just continue, the next
// loop iteration will attempt to expand and we're good
@ -793,7 +811,8 @@ class ExpandVariableCursor : public Cursor {
* case no more expansions are available from the current input
* vertex and another Pull from the input cursor should be performed.
*/
bool Expand(Frame &frame) {
bool Expand(Frame &frame, const SymbolTable &symbol_table) {
ExpressionEvaluator evaluator(frame, symbol_table, db_, self_.graph_view_);
// some expansions might not be valid due to
// edge uniqueness, existing_edge, existing_node criterions,
// so expand in a loop until either the input vertex is
@ -851,6 +870,10 @@ class ExpandVariableCursor : public Cursor {
auto edge_placement_result =
HandleEdgePlacement(current_edge.first, edges_on_frame);
if (edge_placement_result == EdgePlacementResult::MISMATCH) continue;
// Skip expanding out of filtered expansion. It is assumed that the
// expression does not use the vertex which has yet to be put on frame.
// Therefore, this check is done as soon as the edge is on the frame.
if (self_.filter_ && !EvaluateFilter(evaluator, self_.filter_)) continue;
VertexAccessor current_vertex =
current_edge.second == EdgeAtom::Direction::IN
@ -1050,16 +1073,7 @@ bool Filter::FilterCursor::Pull(Frame &frame, const SymbolTable &symbol_table) {
// and edges.
ExpressionEvaluator evaluator(frame, symbol_table, db_, GraphView::OLD);
while (input_cursor_->Pull(frame, symbol_table)) {
TypedValue result = self_.expression_->Accept(evaluator);
// Null is treated like false.
if (result.IsNull()) continue;
if (result.type() != TypedValue::Type::Bool)
throw QueryRuntimeException(
"Filter expression must be a bool or null, but got {}.",
result.type());
if (!result.Value<bool>()) continue;
return true;
if (EvaluateFilter(evaluator, self_.expression_)) return true;
}
return false;
}
@ -1203,11 +1217,11 @@ bool SetProperty::SetPropertyCursor::Pull(Frame &frame,
// Skip setting properties on Null (can occur in optional match).
break;
case TypedValue::Type::Map:
// Semantically modifying a map makes sense, but it's not supported due to
// all the copying we do (when PropertyValue -> TypedValue and in
// ExpressionEvaluator). So even though we set a map property here, that
// is never visible to the user and it's not stored.
// TODO: fix above described bug
// Semantically modifying a map makes sense, but it's not supported due to
// all the copying we do (when PropertyValue -> TypedValue and in
// ExpressionEvaluator). So even though we set a map property here, that
// is never visible to the user and it's not stored.
// TODO: fix above described bug
default:
throw QueryRuntimeException(
"Properties can only be set on Vertices and Edges");
@ -1737,14 +1751,14 @@ void Aggregate::AggregateCursor::Update(
*value_it = 1;
break;
case Aggregation::Op::COLLECT_LIST:
value_it->Value<std::vector<TypedValue>>().push_back(input_value);
break;
value_it->Value<std::vector<TypedValue>>().push_back(input_value);
break;
case Aggregation::Op::COLLECT_MAP:
auto key = agg_elem_it->key->Accept(evaluator);
if (key.type() != TypedValue::Type::String)
throw QueryRuntimeException("Map key must be a string");
value_it->Value<std::map<std::string, TypedValue>>().emplace(
key.Value<std::string>(), input_value);
auto key = agg_elem_it->key->Accept(evaluator);
if (key.type() != TypedValue::Type::String)
throw QueryRuntimeException("Map key must be a string");
value_it->Value<std::map<std::string, TypedValue>>().emplace(
key.Value<std::string>(), input_value);
break;
}
continue;
@ -1789,14 +1803,14 @@ void Aggregate::AggregateCursor::Update(
*value_it = *value_it + input_value;
break;
case Aggregation::Op::COLLECT_LIST:
value_it->Value<std::vector<TypedValue>>().push_back(input_value);
break;
value_it->Value<std::vector<TypedValue>>().push_back(input_value);
break;
case Aggregation::Op::COLLECT_MAP:
auto key = agg_elem_it->key->Accept(evaluator);
if (key.type() != TypedValue::Type::String)
throw QueryRuntimeException("Map key must be a string");
value_it->Value<std::map<std::string, TypedValue>>().emplace(
key.Value<std::string>(), input_value);
auto key = agg_elem_it->key->Accept(evaluator);
if (key.type() != TypedValue::Type::String)
throw QueryRuntimeException("Map key must be a string");
value_it->Value<std::map<std::string, TypedValue>>().emplace(
key.Value<std::string>(), input_value);
break;
} // end switch over Aggregation::Op enum
} // end loop over all aggregations

View File

@ -633,7 +633,8 @@ class ExpandVariable : public LogicalOperator, public ExpandCommon {
Expression *lower_bound, Expression *upper_bound,
const std::shared_ptr<LogicalOperator> &input,
Symbol input_symbol, bool existing_node, bool existing_edge,
GraphView graph_view = GraphView::AS_IS);
GraphView graph_view = GraphView::AS_IS,
Expression *filter = nullptr);
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
@ -646,6 +647,7 @@ class ExpandVariable : public LogicalOperator, public ExpandCommon {
// True if the path should be written as expanding from node_symbol to
// input_symbol.
bool is_reverse_;
Expression *filter_;
};
/**

View File

@ -114,10 +114,9 @@ class UsedSymbolsCollector : public HierarchicalTreeVisitor {
const SymbolTable &symbol_table_;
};
bool HasBoundFilterSymbols(
const std::unordered_set<Symbol> &bound_symbols,
const std::pair<Expression *, std::unordered_set<Symbol>> &filter) {
for (const auto &symbol : filter.second) {
bool HasBoundFilterSymbols(const std::unordered_set<Symbol> &bound_symbols,
const Filters::FilterInfo &filter) {
for (const auto &symbol : filter.used_symbols) {
if (bound_symbols.find(symbol) == bound_symbols.end()) {
return false;
}
@ -357,8 +356,7 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
// Aggregation expression1_ is optional in COUNT(*), and COLLECT_MAP uses
// two expressions, so we can have 0, 1 or 2 elements on the
// has_aggregation_stack for this Aggregation expression.
if (aggr.op_ == Aggregation::Op::COLLECT_MAP)
has_aggregation_.pop_back();
if (aggr.op_ == Aggregation::Op::COLLECT_MAP) has_aggregation_.pop_back();
if (aggr.expression1_)
has_aggregation_.back() = true;
else
@ -594,11 +592,178 @@ void AddMatching(const Match &match, SymbolTable &symbol_table,
matching);
}
// Iterates over `all_filters` joining them in one expression via
// `FilterAndOperator`. Filters which use unbound symbols are skipped, as well
// as those that fail the `predicate` function. The function takes a single
// argument, `FilterInfo`. All the joined filters are removed from
// `all_filters`.
template <class TPredicate>
Expression *ExtractFilters(const std::unordered_set<Symbol> &bound_symbols,
std::vector<Filters::FilterInfo> &all_filters,
AstTreeStorage &storage,
const TPredicate &predicate) {
Expression *filter_expr = nullptr;
for (auto filters_it = all_filters.begin();
filters_it != all_filters.end();) {
if (HasBoundFilterSymbols(bound_symbols, *filters_it) &&
predicate(*filters_it)) {
filter_expr = BoolJoin<FilterAndOperator>(storage, filter_expr,
filters_it->expression);
filters_it = all_filters.erase(filters_it);
} else {
filters_it++;
}
}
return filter_expr;
}
} // namespace
namespace impl {
// Returns false if the symbol was already bound, otherwise binds it and
// returns true.
bool BindSymbol(std::unordered_set<Symbol> &bound_symbols,
const Symbol &symbol) {
auto insertion = bound_symbols.insert(symbol);
return insertion.second;
}
Expression *FindExpandVariableFilter(
const std::unordered_set<Symbol> &bound_symbols,
const Symbol &expands_to_node,
std::vector<Filters::FilterInfo> &all_filters, AstTreeStorage &storage) {
return ExtractFilters(bound_symbols, all_filters, storage,
[&](const auto &filter) {
return filter.is_for_expand_variable &&
filter.used_symbols.find(expands_to_node) ==
filter.used_symbols.end();
});
}
LogicalOperator *GenFilters(LogicalOperator *last_op,
const std::unordered_set<Symbol> &bound_symbols,
std::vector<Filters::FilterInfo> &all_filters,
AstTreeStorage &storage) {
auto *filter_expr = ExtractFilters(bound_symbols, all_filters, storage,
[](const auto &) { return true; });
if (filter_expr) {
last_op =
new Filter(std::shared_ptr<LogicalOperator>(last_op), filter_expr);
}
return last_op;
}
LogicalOperator *GenReturn(Return &ret, LogicalOperator *input_op,
SymbolTable &symbol_table, bool is_write,
const std::unordered_set<Symbol> &bound_symbols,
AstTreeStorage &storage) {
// Similar to WITH clause, but we want to accumulate and advance command when
// the query writes to the database. This way we handle the case when we want
// to return expressions with the latest updated results. For example,
// `MATCH (n) -- () SET n.prop = n.prop + 1 RETURN n.prop`. If we match same
// `n` multiple 'k' times, we want to return 'k' results where the property
// value is the same, final result of 'k' increments.
bool accumulate = is_write;
bool advance_command = false;
ReturnBodyContext body(ret.body_, symbol_table, bound_symbols, storage);
return GenReturnBody(input_op, advance_command, body, accumulate);
}
LogicalOperator *GenCreateForPattern(
Pattern &pattern, LogicalOperator *input_op,
const SymbolTable &symbol_table,
std::unordered_set<Symbol> &bound_symbols) {
auto base = [&](NodeAtom *node) -> LogicalOperator * {
if (BindSymbol(bound_symbols, symbol_table.at(*node->identifier_)))
return new CreateNode(node, std::shared_ptr<LogicalOperator>(input_op));
else
return input_op;
};
auto collect = [&](LogicalOperator *last_op, NodeAtom *prev_node,
EdgeAtom *edge, NodeAtom *node) {
// Store the symbol from the first node as the input to CreateExpand.
const auto &input_symbol = symbol_table.at(*prev_node->identifier_);
// If the expand node was already bound, then we need to indicate this,
// so that CreateExpand only creates an edge.
bool node_existing = false;
if (!BindSymbol(bound_symbols, symbol_table.at(*node->identifier_))) {
node_existing = true;
}
if (!BindSymbol(bound_symbols, symbol_table.at(*edge->identifier_))) {
permanent_fail("Symbols used for created edges cannot be redeclared.");
}
return new CreateExpand(node, edge,
std::shared_ptr<LogicalOperator>(last_op),
input_symbol, node_existing);
};
return ReducePattern<LogicalOperator *>(pattern, base, collect);
}
// Generate an operator for a clause which writes to the database. If the clause
// isn't handled, returns nullptr.
LogicalOperator *HandleWriteClause(Clause *clause, LogicalOperator *input_op,
const SymbolTable &symbol_table,
std::unordered_set<Symbol> &bound_symbols) {
if (auto *create = dynamic_cast<Create *>(clause)) {
return GenCreate(*create, input_op, symbol_table, bound_symbols);
} else if (auto *del = dynamic_cast<query::Delete *>(clause)) {
return new plan::Delete(std::shared_ptr<LogicalOperator>(input_op),
del->expressions_, del->detach_);
} else if (auto *set = dynamic_cast<query::SetProperty *>(clause)) {
return new plan::SetProperty(std::shared_ptr<LogicalOperator>(input_op),
set->property_lookup_, set->expression_);
} else if (auto *set = dynamic_cast<query::SetProperties *>(clause)) {
auto op = set->update_ ? plan::SetProperties::Op::UPDATE
: plan::SetProperties::Op::REPLACE;
const auto &input_symbol = symbol_table.at(*set->identifier_);
return new plan::SetProperties(std::shared_ptr<LogicalOperator>(input_op),
input_symbol, set->expression_, op);
} else if (auto *set = dynamic_cast<query::SetLabels *>(clause)) {
const auto &input_symbol = symbol_table.at(*set->identifier_);
return new plan::SetLabels(std::shared_ptr<LogicalOperator>(input_op),
input_symbol, set->labels_);
} else if (auto *rem = dynamic_cast<query::RemoveProperty *>(clause)) {
return new plan::RemoveProperty(std::shared_ptr<LogicalOperator>(input_op),
rem->property_lookup_);
} else if (auto *rem = dynamic_cast<query::RemoveLabels *>(clause)) {
const auto &input_symbol = symbol_table.at(*rem->identifier_);
return new plan::RemoveLabels(std::shared_ptr<LogicalOperator>(input_op),
input_symbol, rem->labels_);
}
return nullptr;
}
LogicalOperator *GenWith(With &with, LogicalOperator *input_op,
SymbolTable &symbol_table, bool is_write,
std::unordered_set<Symbol> &bound_symbols,
AstTreeStorage &storage) {
// WITH clause is Accumulate/Aggregate (advance_command) + Produce and
// optional Filter. In case of update and aggregation, we want to accumulate
// first, so that when aggregating, we get the latest results. Similar to
// RETURN clause.
bool accumulate = is_write;
// No need to advance the command if we only performed reads.
bool advance_command = is_write;
ReturnBodyContext body(with.body_, symbol_table, bound_symbols, storage,
with.where_);
LogicalOperator *last_op =
GenReturnBody(input_op, advance_command, body, accumulate);
// Reset bound symbols, so that only those in WITH are exposed.
bound_symbols.clear();
for (const auto &symbol : body.output_symbols()) {
BindSymbol(bound_symbols, symbol);
}
return last_op;
}
} // namespace impl
// Analyzes the filter expression by collecting information on filtering labels
// and properties to be used with indexing. Note that all filters are never
// updated here, but only labels and properties are.
// and properties to be used with indexing. Note that `all_filters_` are never
// updated here, but only `label_filters_` and `property_filters_` are.
void Filters::AnalyzeFilter(Expression *expr, const SymbolTable &symbol_table) {
using Bound = ScanAllByLabelPropertyRange::Bound;
auto get_property_lookup = [](auto *maybe_lookup, auto *&prop_lookup,
@ -714,11 +879,11 @@ void Filters::CollectPatternFilters(Pattern &pattern, SymbolTable &symbol_table,
collector.symbols_.insert(symbol); // PropertyLookup uses the symbol.
if (is_variable_path) {
all_filters_.emplace_back(
storage.Create<All>(identifier, atom->identifier_,
storage.Create<Where>(prop_equal)),
collector.symbols_);
FilterInfo{storage.Create<All>(identifier, atom->identifier_,
storage.Create<Where>(prop_equal)),
collector.symbols_, true});
} else {
all_filters_.emplace_back(prop_equal, collector.symbols_);
all_filters_.emplace_back(FilterInfo{prop_equal, collector.symbols_});
}
}
};
@ -729,9 +894,9 @@ void Filters::CollectPatternFilters(Pattern &pattern, SymbolTable &symbol_table,
label_filters_[node_symbol].insert(node->labels_.begin(),
node->labels_.end());
// Create a LabelsTest and store it in all_filters_.
all_filters_.emplace_back(
all_filters_.emplace_back(FilterInfo{
storage.Create<LabelsTest>(node->identifier_, node->labels_),
std::unordered_set<Symbol>{node_symbol});
std::unordered_set<Symbol>{node_symbol}});
}
add_properties_filter(node);
};
@ -740,19 +905,19 @@ void Filters::CollectPatternFilters(Pattern &pattern, SymbolTable &symbol_table,
if (!edge->edge_types_.empty()) {
if (edge->has_range_) {
// We need a new identifier and symbol for All.
auto *identifier = edge->identifier_->Clone(storage);
symbol_table[*identifier] =
symbol_table.CreateSymbol(identifier->name_, false);
auto *ident_in_all = edge->identifier_->Clone(storage);
symbol_table[*ident_in_all] =
symbol_table.CreateSymbol(ident_in_all->name_, false);
auto *edge_type_test =
storage.Create<EdgeTypeTest>(identifier, edge->edge_types_);
all_filters_.emplace_back(
storage.Create<All>(identifier, edge->identifier_,
storage.Create<EdgeTypeTest>(ident_in_all, edge->edge_types_);
all_filters_.emplace_back(FilterInfo{
storage.Create<All>(ident_in_all, edge->identifier_,
storage.Create<Where>(edge_type_test)),
std::unordered_set<Symbol>{edge_symbol});
std::unordered_set<Symbol>{edge_symbol}, true});
} else {
all_filters_.emplace_back(
all_filters_.emplace_back(FilterInfo{
storage.Create<EdgeTypeTest>(edge->identifier_, edge->edge_types_),
std::unordered_set<Symbol>{edge_symbol});
std::unordered_set<Symbol>{edge_symbol}});
}
}
add_properties_filter(edge, edge->has_range_);
@ -761,13 +926,13 @@ void Filters::CollectPatternFilters(Pattern &pattern, SymbolTable &symbol_table,
ForEachPattern(pattern, add_node_filter, add_expand_filter);
}
// Adds the where filter expression to all filters and collects additional
// Adds the where filter expression to `all_filters_` and collects additional
// information for potential property and label indexing.
void Filters::CollectWhereFilter(Where &where,
const SymbolTable &symbol_table) {
UsedSymbolsCollector collector(symbol_table);
where.expression_->Accept(collector);
all_filters_.emplace_back(where.expression_, collector.symbols_);
all_filters_.emplace_back(FilterInfo{where.expression_, collector.symbols_});
AnalyzeFilter(where.expression_, symbol_table);
}
@ -809,144 +974,4 @@ std::vector<QueryPart> CollectQueryParts(SymbolTable &symbol_table,
return query_parts;
}
namespace impl {
// Returns false if the symbol was already bound, otherwise binds it and
// returns true.
bool BindSymbol(std::unordered_set<Symbol> &bound_symbols,
const Symbol &symbol) {
auto insertion = bound_symbols.insert(symbol);
return insertion.second;
}
LogicalOperator *GenFilters(
LogicalOperator *last_op, const std::unordered_set<Symbol> &bound_symbols,
std::vector<std::pair<Expression *, std::unordered_set<Symbol>>>
&all_filters,
AstTreeStorage &storage) {
Expression *filter_expr = nullptr;
for (auto filters_it = all_filters.begin();
filters_it != all_filters.end();) {
if (HasBoundFilterSymbols(bound_symbols, *filters_it)) {
filter_expr =
BoolJoin<FilterAndOperator>(storage, filter_expr, filters_it->first);
filters_it = all_filters.erase(filters_it);
} else {
filters_it++;
}
}
if (filter_expr) {
last_op =
new Filter(std::shared_ptr<LogicalOperator>(last_op), filter_expr);
}
return last_op;
}
LogicalOperator *GenReturn(Return &ret, LogicalOperator *input_op,
SymbolTable &symbol_table, bool is_write,
const std::unordered_set<Symbol> &bound_symbols,
AstTreeStorage &storage) {
// Similar to WITH clause, but we want to accumulate and advance command when
// the query writes to the database. This way we handle the case when we want
// to return expressions with the latest updated results. For example,
// `MATCH (n) -- () SET n.prop = n.prop + 1 RETURN n.prop`. If we match same
// `n` multiple 'k' times, we want to return 'k' results where the property
// value is the same, final result of 'k' increments.
bool accumulate = is_write;
bool advance_command = false;
ReturnBodyContext body(ret.body_, symbol_table, bound_symbols, storage);
return GenReturnBody(input_op, advance_command, body, accumulate);
}
LogicalOperator *GenCreateForPattern(
Pattern &pattern, LogicalOperator *input_op,
const SymbolTable &symbol_table,
std::unordered_set<Symbol> &bound_symbols) {
auto base = [&](NodeAtom *node) -> LogicalOperator * {
if (BindSymbol(bound_symbols, symbol_table.at(*node->identifier_)))
return new CreateNode(node, std::shared_ptr<LogicalOperator>(input_op));
else
return input_op;
};
auto collect = [&](LogicalOperator *last_op, NodeAtom *prev_node,
EdgeAtom *edge, NodeAtom *node) {
// Store the symbol from the first node as the input to CreateExpand.
const auto &input_symbol = symbol_table.at(*prev_node->identifier_);
// If the expand node was already bound, then we need to indicate this,
// so that CreateExpand only creates an edge.
bool node_existing = false;
if (!BindSymbol(bound_symbols, symbol_table.at(*node->identifier_))) {
node_existing = true;
}
if (!BindSymbol(bound_symbols, symbol_table.at(*edge->identifier_))) {
permanent_fail("Symbols used for created edges cannot be redeclared.");
}
return new CreateExpand(node, edge,
std::shared_ptr<LogicalOperator>(last_op),
input_symbol, node_existing);
};
return ReducePattern<LogicalOperator *>(pattern, base, collect);
}
// Generate an operator for a clause which writes to the database. If the clause
// isn't handled, returns nullptr.
LogicalOperator *HandleWriteClause(Clause *clause, LogicalOperator *input_op,
const SymbolTable &symbol_table,
std::unordered_set<Symbol> &bound_symbols) {
if (auto *create = dynamic_cast<Create *>(clause)) {
return GenCreate(*create, input_op, symbol_table, bound_symbols);
} else if (auto *del = dynamic_cast<query::Delete *>(clause)) {
return new plan::Delete(std::shared_ptr<LogicalOperator>(input_op),
del->expressions_, del->detach_);
} else if (auto *set = dynamic_cast<query::SetProperty *>(clause)) {
return new plan::SetProperty(std::shared_ptr<LogicalOperator>(input_op),
set->property_lookup_, set->expression_);
} else if (auto *set = dynamic_cast<query::SetProperties *>(clause)) {
auto op = set->update_ ? plan::SetProperties::Op::UPDATE
: plan::SetProperties::Op::REPLACE;
const auto &input_symbol = symbol_table.at(*set->identifier_);
return new plan::SetProperties(std::shared_ptr<LogicalOperator>(input_op),
input_symbol, set->expression_, op);
} else if (auto *set = dynamic_cast<query::SetLabels *>(clause)) {
const auto &input_symbol = symbol_table.at(*set->identifier_);
return new plan::SetLabels(std::shared_ptr<LogicalOperator>(input_op),
input_symbol, set->labels_);
} else if (auto *rem = dynamic_cast<query::RemoveProperty *>(clause)) {
return new plan::RemoveProperty(std::shared_ptr<LogicalOperator>(input_op),
rem->property_lookup_);
} else if (auto *rem = dynamic_cast<query::RemoveLabels *>(clause)) {
const auto &input_symbol = symbol_table.at(*rem->identifier_);
return new plan::RemoveLabels(std::shared_ptr<LogicalOperator>(input_op),
input_symbol, rem->labels_);
}
return nullptr;
}
LogicalOperator *GenWith(With &with, LogicalOperator *input_op,
SymbolTable &symbol_table, bool is_write,
std::unordered_set<Symbol> &bound_symbols,
AstTreeStorage &storage) {
// WITH clause is Accumulate/Aggregate (advance_command) + Produce and
// optional Filter. In case of update and aggregation, we want to accumulate
// first, so that when aggregating, we get the latest results. Similar to
// RETURN clause.
bool accumulate = is_write;
// No need to advance the command if we only performed reads.
bool advance_command = is_write;
ReturnBodyContext body(with.body_, symbol_table, bound_symbols, storage,
with.where_);
LogicalOperator *last_op =
GenReturnBody(input_op, advance_command, body, accumulate);
// Reset bound symbols, so that only those in WITH are exposed.
bound_symbols.clear();
for (const auto &symbol : body.output_symbols()) {
BindSymbol(bound_symbols, symbol);
}
return last_op;
}
} // namespace impl
} // namespace query::plan

View File

@ -38,7 +38,19 @@ class Filters {
std::experimental::optional<Bound> upper_bound{};
};
/// All filter expressions that should be generated.
/// Stores additional information for a filter expression.
struct FilterInfo {
/// The filter expression which must be satisfied.
Expression *expression;
/// Set of used symbols by the filter @c expression.
std::unordered_set<Symbol> used_symbols;
/// True if the filter is to be applied on multiple expanding edges.
/// This is used to inline filtering in an @c ExpandVariable operator.
bool is_for_expand_variable = false;
};
/// List of FilterInfo objects corresponding to all filter expressions that
/// should be generated.
auto &all_filters() { return all_filters_; }
const auto &all_filters() const { return all_filters_; }
/// Mapping from a symbol to labels that are filtered on it. These should be
@ -66,7 +78,7 @@ class Filters {
private:
void AnalyzeFilter(Expression *, const SymbolTable &);
std::vector<std::pair<Expression *, std::unordered_set<Symbol>>> all_filters_;
std::vector<FilterInfo> all_filters_;
std::unordered_map<Symbol, std::set<GraphDbTypes::Label>> label_filters_;
std::unordered_map<
Symbol, std::map<GraphDbTypes::Property, std::vector<PropertyFilter>>>
@ -190,11 +202,20 @@ namespace impl {
bool BindSymbol(std::unordered_set<Symbol> &bound_symbols,
const Symbol &symbol);
LogicalOperator *GenFilters(
LogicalOperator *last_op, const std::unordered_set<Symbol> &bound_symbols,
std::vector<std::pair<Expression *, std::unordered_set<Symbol>>>
&all_filters,
AstTreeStorage &storage);
// Looks for filter expressions, which can be inlined in an ExpandVariable
// operator. Such expressions are merged into one (via `and`) and removed from
// `all_filters`. If the expression uses `expands_to_node`, it is skipped. In
// such a case, we cannot cut variable expand short, since filtering may be
// satisfied by a node deeper in the path.
Expression *FindExpandVariableFilter(
const std::unordered_set<Symbol> &bound_symbols,
const Symbol &expands_to_node,
std::vector<Filters::FilterInfo> &all_filters, AstTreeStorage &storage);
LogicalOperator *GenFilters(LogicalOperator *last_op,
const std::unordered_set<Symbol> &bound_symbols,
std::vector<Filters::FilterInfo> &all_filters,
AstTreeStorage &storage);
LogicalOperator *GenReturn(Return &ret, LogicalOperator *input_op,
SymbolTable &symbol_table, bool is_write,
@ -464,12 +485,15 @@ class RuleBasedPlanner {
std::shared_ptr<LogicalOperator>(last_op), node1_symbol,
existing_node, match_context.graph_view);
} else if (expansion.edge->has_range_) {
auto *filter_expr = impl::FindExpandVariableFilter(
bound_symbols, node_symbol, all_filters, storage);
last_op = new ExpandVariable(
node_symbol, edge_symbol, expansion.direction,
expansion.direction != expansion.edge->direction_,
expansion.edge->lower_bound_, expansion.edge->upper_bound_,
std::shared_ptr<LogicalOperator>(last_op), node1_symbol,
existing_node, existing_edge, match_context.graph_view);
existing_node, existing_edge, match_context.graph_view,
filter_expr);
} else {
last_op = new Expand(node_symbol, edge_symbol, expansion.direction,
std::shared_ptr<LogicalOperator>(last_op),

View File

@ -1252,7 +1252,7 @@ TEST(TestLogicalPlanner, MatchExpandVariableNoBounds) {
CheckPlan(storage, ExpectScanAll(), ExpectExpandVariable(), ExpectProduce());
}
TEST(TestLogicalPlanner, MatchExpandVariableFiltered) {
TEST(TestLogicalPlanner, MatchExpandVariableInlinedFilter) {
// Test MATCH (n) -[r :type * {prop: 42}]-> (m) RETURN r
Dbms dbms;
auto dba = dbms.active();
@ -1263,6 +1263,22 @@ TEST(TestLogicalPlanner, MatchExpandVariableFiltered) {
edge->has_range_ = true;
edge->properties_[prop] = LITERAL(42);
QUERY(MATCH(PATTERN(NODE("n"), edge, NODE("m"))), RETURN("r"));
CheckPlan(storage, ExpectScanAll(),
ExpectExpandVariable(), // Filter is inlined in expand
ExpectProduce());
}
TEST(TestLogicalPlanner, MatchExpandVariableNotInlinedFilter) {
// Test MATCH (n) -[r :type * {prop: m.prop}]-> (m) RETURN r
Dbms dbms;
auto dba = dbms.active();
auto type = dba->EdgeType("type");
auto prop = PROPERTY_PAIR("prop");
AstTreeStorage storage;
auto edge = EDGE("r", type);
edge->has_range_ = true;
edge->properties_[prop] = EQ(PROPERTY_LOOKUP("m", prop), LITERAL(42));
QUERY(MATCH(PATTERN(NODE("n"), edge, NODE("m"))), RETURN("r"));
CheckPlan(storage, ExpectScanAll(), ExpectExpandVariable(), ExpectFilter(),
ExpectProduce());
}