Query::Plan::Accumulate[Advance] implementation and test
Summary: see above Reviewers: teon.banek, buda, mislav.bradac Reviewed By: teon.banek Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D226
This commit is contained in:
parent
5243ab00c2
commit
4c73a0a71c
@ -7,7 +7,7 @@
|
||||
#include "query/frontend/ast/ast.hpp"
|
||||
#include "query/frontend/semantic/symbol_table.hpp"
|
||||
#include "utils/assert.hpp"
|
||||
#include <utils/exceptions/not_yet_implemented.hpp>
|
||||
#include "utils/exceptions/not_yet_implemented.hpp"
|
||||
|
||||
namespace query {
|
||||
|
||||
|
@ -185,8 +185,7 @@ bool ScanAll::ScanAllCursor::Pull(Frame &frame, SymbolTable &symbol_table) {
|
||||
// if we have no more vertices, we're done (if input_ is set we have
|
||||
// just tried to re-init vertices_it_, and if not we only iterate
|
||||
// through it once
|
||||
if (vertices_it_ == vertices_.end())
|
||||
return false;
|
||||
if (vertices_it_ == vertices_.end()) return false;
|
||||
|
||||
frame[symbol_table[*self_.node_atom_->identifier_]] = *vertices_it_++;
|
||||
return true;
|
||||
@ -791,7 +790,8 @@ ExpandUniquenessFilter<TAccessor>::ExpandUniquenessFilter(
|
||||
previous_symbols_(previous_symbols) {}
|
||||
|
||||
template <typename TAccessor>
|
||||
void ExpandUniquenessFilter<TAccessor>::Accept(LogicalOperatorVisitor &visitor) {
|
||||
void ExpandUniquenessFilter<TAccessor>::Accept(
|
||||
LogicalOperatorVisitor &visitor) {
|
||||
visitor.Visit(*this);
|
||||
input_->Accept(visitor);
|
||||
visitor.PostVisit(*this);
|
||||
@ -812,7 +812,6 @@ ExpandUniquenessFilter<TAccessor>::ExpandUniquenessFilterCursor::
|
||||
template <typename TAccessor>
|
||||
bool ExpandUniquenessFilter<TAccessor>::ExpandUniquenessFilterCursor::Pull(
|
||||
Frame &frame, SymbolTable &symbol_table) {
|
||||
|
||||
auto expansion_ok = [&]() {
|
||||
TypedValue &expand_value = frame[self_.expand_symbol_];
|
||||
TAccessor &expand_accessor = expand_value.Value<TAccessor>();
|
||||
@ -825,8 +824,7 @@ bool ExpandUniquenessFilter<TAccessor>::ExpandUniquenessFilterCursor::Pull(
|
||||
};
|
||||
|
||||
while (input_cursor_->Pull(frame, symbol_table))
|
||||
if (expansion_ok())
|
||||
return true;
|
||||
if (expansion_ok()) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -835,8 +833,38 @@ bool ExpandUniquenessFilter<TAccessor>::ExpandUniquenessFilterCursor::Pull(
|
||||
template class ExpandUniquenessFilter<VertexAccessor>;
|
||||
template class ExpandUniquenessFilter<EdgeAccessor>;
|
||||
|
||||
Accumulate::Accumulate(std::shared_ptr<LogicalOperator> input, const std::vector<Symbol> &symbols) :
|
||||
input_(input), symbols_(symbols) {}
|
||||
namespace {
|
||||
|
||||
/**
|
||||
* Helper function for recursively reconstructing all the accessors in the
|
||||
* given TypedValue.
|
||||
*/
|
||||
void ReconstructTypedValue(TypedValue &value) {
|
||||
switch (value.type()) {
|
||||
case TypedValue::Type::Vertex:
|
||||
value.Value<VertexAccessor>().Reconstruct();
|
||||
break;
|
||||
case TypedValue::Type::Edge:
|
||||
value.Value<EdgeAccessor>().Reconstruct();
|
||||
break;
|
||||
case TypedValue::Type::List:
|
||||
for (TypedValue &inner_value : value.Value<std::vector<TypedValue>>())
|
||||
ReconstructTypedValue(inner_value);
|
||||
break;
|
||||
case TypedValue::Type::Map:
|
||||
for (auto &kv : value.Value<std::map<std::string, TypedValue>>())
|
||||
ReconstructTypedValue(kv.second);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
// TODO implement path reconstruct?
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stores the input operator, the symbols to accumulate and the flag that
// selects whether the command is advanced once the input is exhausted.
Accumulate::Accumulate(std::shared_ptr<LogicalOperator> input,
                       const std::vector<Symbol> &symbols,
                       bool advance_command)
    : input_(input),
      symbols_(symbols),
      advance_command_(advance_command) {}
|
||||
|
||||
void Accumulate::Accept(LogicalOperatorVisitor &visitor) {
|
||||
visitor.Visit(*this);
|
||||
@ -844,18 +872,37 @@ void Accumulate::Accept(LogicalOperatorVisitor &visitor) {
|
||||
visitor.PostVisit(*this);
|
||||
}
|
||||
std::unique_ptr<Cursor> Accumulate::MakeCursor(GraphDbAccessor &db) {
|
||||
return std::unique_ptr<Cursor>();
|
||||
return std::make_unique<Accumulate::AccumulateCursor>(*this, db);
|
||||
}
|
||||
|
||||
AdvanceCommand::AdvanceCommand(std::shared_ptr<LogicalOperator> input) : input_(input) {}
|
||||
// Binds the cursor to its operator and database accessor, and creates the
// cursor of the operator's input from which rows will be pulled.
Accumulate::AccumulateCursor::AccumulateCursor(Accumulate &self,
                                               GraphDbAccessor &db)
    : self_(self),
      db_(db),
      input_cursor_(self.input_->MakeCursor(db)) {}
|
||||
|
||||
void AdvanceCommand::Accept(LogicalOperatorVisitor &visitor) {
|
||||
visitor.Visit(*this);
|
||||
input_->Accept(visitor);
|
||||
visitor.PostVisit(*this);
|
||||
}
|
||||
std::unique_ptr<Cursor> AdvanceCommand::MakeCursor(GraphDbAccessor &db) {
|
||||
return std::unique_ptr<Cursor>();
|
||||
// Emits one cached row per call. The first call drains the input cursor
// completely, caching the values of the accumulated symbols for every
// pulled frame, and optionally advances the command before emitting.
// Returns false once all the cached rows have been emitted.
bool Accumulate::AccumulateCursor::Pull(Frame &frame,
                                        SymbolTable &symbol_table) {
  if (!pulled_all_input_) {
    // exhaust the input, storing one cached row per successful pull
    while (input_cursor_->Pull(frame, symbol_table)) {
      cache_.emplace_back();
      for (const Symbol &symbol : self_.symbols_) {
        cache_.back().emplace_back(frame[symbol]);
      }
    }
    pulled_all_input_ = true;
    cache_it_ = cache_.begin();

    // when requested, advance the command and reconstruct everything
    // that was cached
    if (self_.advance_command_) {
      db_.advance_command();
      for (auto &cached_row : cache_) {
        for (auto &cached_value : cached_row) {
          ReconstructTypedValue(cached_value);
        }
      }
    }
  }

  if (cache_it_ != cache_.end()) {
    // step past the current row first, then copy its values into the frame
    auto &current_row = *cache_it_;
    ++cache_it_;
    auto value_it = current_row.begin();
    for (const Symbol &symbol : self_.symbols_) {
      frame[symbol] = *value_it;
      ++value_it;
    }
    return true;
  }
  return false;
}
|
||||
|
||||
} // namespace plan
|
||||
|
@ -742,25 +742,57 @@ class ExpandUniquenessFilter : public LogicalOperator {
|
||||
};
|
||||
};
|
||||
|
||||
/** @brief Pulls everything from the input before passing it through.
|
||||
* Optionally advances the command after accumulation and before emitting.
|
||||
*
|
||||
* On the first Pull from this Op's Cursor the input Cursor will be
|
||||
* Pulled until it is empty. The results will be accumulated in the
|
||||
* temporary cache. Once the input Cursor is empty, this Op's Cursor
|
||||
* will start returning cached stuff from its Pull.
|
||||
*
|
||||
* This technique is used for ensuring all the operations from the
|
||||
* previous LogicalOp have been performed before exposing data
|
||||
* to the next. A typical use-case is the `MATCH - SET - RETURN`
|
||||
* query in which every SET iteration must be performed before
|
||||
* RETURN starts iterating (see Memgraph Wiki for detailed reasoning).
|
||||
*
|
||||
* IMPORTANT: This Op does not cache all the results but only those
|
||||
* elements from the frame whose symbols (frame positions) it was given.
|
||||
* All other frame positions will contain undefined junk after this
|
||||
* op has executed, and should not be used.
|
||||
*
|
||||
* This op can also advance the command after the accumulation and
|
||||
* before emitting. If the command gets advanced, every value that
|
||||
* has been cached will be reconstructed before Pull returns.
|
||||
*
|
||||
* @param input Input @c LogicalOperator.
|
||||
* @param symbols A vector of Symbols that need to be accumulated
|
||||
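* and exposed to the next op.
* @param advance_command If true, the command is advanced after the
* accumulation and before emitting (defaults to false).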
|
||||
*/
|
||||
class Accumulate : public LogicalOperator {
|
||||
public:
|
||||
Accumulate(std::shared_ptr<LogicalOperator> input, const std::vector<Symbol> &symbols);
|
||||
Accumulate(std::shared_ptr<LogicalOperator> input, const std::vector<Symbol> &symbols,
|
||||
bool advance_command=false);
|
||||
void Accept(LogicalOperatorVisitor &visitor) override;
|
||||
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<LogicalOperator> input_;
|
||||
const std::vector<Symbol> symbols_;
|
||||
};
|
||||
bool advance_command_{false};
|
||||
|
||||
class AdvanceCommand : public LogicalOperator {
|
||||
public:
|
||||
AdvanceCommand(std::shared_ptr<LogicalOperator> input);
|
||||
void Accept(LogicalOperatorVisitor &visitor) override;
|
||||
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<LogicalOperator> input_;
|
||||
class AccumulateCursor : public Cursor {
|
||||
public:
|
||||
AccumulateCursor(Accumulate &self, GraphDbAccessor &db);
|
||||
bool Pull(Frame &frame, SymbolTable &symbol_table) override;
|
||||
private:
|
||||
Accumulate &self_;
|
||||
GraphDbAccessor &db_;
|
||||
std::unique_ptr<Cursor> input_cursor_;
|
||||
std::list<std::list<TypedValue>> cache_;
|
||||
decltype(cache_.begin()) cache_it_ = cache_.begin();
|
||||
bool pulled_all_input_{false};
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace plan
|
||||
|
@ -1265,3 +1265,88 @@ TEST(Interpreter, ExpandUniquenessFilter) {
|
||||
EXPECT_EQ(0, check_expand_results(true, false));
|
||||
EXPECT_EQ(1, check_expand_results(false, true));
|
||||
}
|
||||
|
||||
TEST(Interpreter, Accumulate) {
  // Simulates executing the following two queries on an empty db:
  //   CREATE ({x:0})-[:T]->({x:0})
  //   MATCH (n)--(m) SET n.x = n.x + 1, m.x = m.x + 1 RETURN n.x, m.x
  // Without accumulation we expect the results to be [[1, 1], [2, 2]],
  // with accumulation we expect them to be [[2, 2], [2, 2]].
  for (const bool accumulate : {false, true}) {
    Dbms dbms;
    auto dba = dbms.active();
    auto prop = dba->property("x");

    // two vertices with x = 0, connected by a single edge
    auto first_vertex = dba->insert_vertex();
    first_vertex.PropsSet(prop, 0);
    auto second_vertex = dba->insert_vertex();
    second_vertex.PropsSet(prop, 0);
    dba->insert_edge(first_vertex, second_vertex, dba->edge_type("T"));
    dba->advance_command();

    AstTreeStorage storage;
    SymbolTable symbol_table;

    // MATCH (n)--(m)
    auto n = MakeScanAll(storage, symbol_table, "n");
    auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r",
                          EdgeAtom::Direction::BOTH, false, "m", false);

    // SET n.x = n.x + 1, m.x = m.x + 1
    auto one = LITERAL(1);
    auto n_p = PROPERTY_LOOKUP("n", prop);
    symbol_table[*n_p->expression_] = n.sym_;
    auto set_n_p =
        std::make_shared<plan::SetProperty>(r_m.op_, n_p, ADD(n_p, one));
    auto m_p = PROPERTY_LOOKUP("m", prop);
    symbol_table[*m_p->expression_] = r_m.node_sym_;
    auto set_m_p =
        std::make_shared<plan::SetProperty>(set_n_p, m_p, ADD(m_p, one));

    // optionally put the accumulation between the SETs and the RETURN
    std::shared_ptr<LogicalOperator> plan_root = set_m_p;
    if (accumulate) {
      plan_root = std::make_shared<Accumulate>(
          plan_root, std::vector<Symbol>{n.sym_, r_m.node_sym_});
    }

    // RETURN n.x, m.x
    auto n_p_ne = NEXPR("n.p", n_p);
    symbol_table[*n_p_ne] = symbol_table.CreateSymbol("n_p_ne");
    auto m_p_ne = NEXPR("m.p", m_p);
    symbol_table[*m_p_ne] = symbol_table.CreateSymbol("m_p_ne");
    auto produce = MakeProduce(plan_root, n_p_ne, m_p_ne);
    ResultStreamFaker results = CollectProduce(produce, symbol_table, *dba);

    // flatten the result rows into a single sequence for comparison
    std::vector<int> flattened;
    for (const auto &row : results.GetResults()) {
      for (const auto &column : row) {
        flattened.emplace_back(column.Value<int64_t>());
      }
    }

    if (accumulate) {
      EXPECT_THAT(flattened, testing::ElementsAre(2, 2, 2, 2));
    } else {
      EXPECT_THAT(flattened, testing::ElementsAre(1, 1, 2, 2));
    }
  }
}
|
||||
|
||||
TEST(Interpreter, AccumulateAdvance) {
  // Simulates 'CREATE (n) WITH n AS n MATCH (m) RETURN m'.
  // To get correct results the command needs to be advanced, so one row is
  // expected when advancing and none otherwise.
  for (const bool advance : {false, true}) {
    Dbms dbms;
    auto dba = dbms.active();
    AstTreeStorage storage;
    SymbolTable symbol_table;

    // CREATE (n)
    auto node = NODE("n");
    auto sym_n = symbol_table.CreateSymbol("n");
    symbol_table[*node->identifier_] = sym_n;
    auto create = std::make_shared<CreateNode>(node, nullptr);

    // accumulate the created node, optionally advancing the command
    auto accumulate = std::make_shared<Accumulate>(
        create, std::vector<Symbol>{sym_n}, advance);

    // MATCH (m)
    auto match = MakeScanAll(storage, symbol_table, "m", accumulate);
    EXPECT_EQ(advance ? 1 : 0, PullAll(match.op_, *dba, symbol_table));
  }
}
|
||||
|
Loading…
Reference in New Issue
Block a user