Query::Plan::Accumulate[Advance] implementation and test

Summary: see above

Reviewers: teon.banek, buda, mislav.bradac

Reviewed By: teon.banek

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D226
This commit is contained in:
florijan 2017-04-05 15:00:26 +02:00
parent 5243ab00c2
commit 4c73a0a71c
4 changed files with 192 additions and 28 deletions

View File

@ -7,7 +7,7 @@
#include "query/frontend/ast/ast.hpp"
#include "query/frontend/semantic/symbol_table.hpp"
#include "utils/assert.hpp"
#include <utils/exceptions/not_yet_implemented.hpp>
#include "utils/exceptions/not_yet_implemented.hpp"
namespace query {

View File

@ -185,8 +185,7 @@ bool ScanAll::ScanAllCursor::Pull(Frame &frame, SymbolTable &symbol_table) {
// if we have no more vertices, we're done (if input_ is set we have
// just tried to re-init vertices_it_, and if not we only iterate
// through it once
if (vertices_it_ == vertices_.end())
return false;
if (vertices_it_ == vertices_.end()) return false;
frame[symbol_table[*self_.node_atom_->identifier_]] = *vertices_it_++;
return true;
@ -791,7 +790,8 @@ ExpandUniquenessFilter<TAccessor>::ExpandUniquenessFilter(
previous_symbols_(previous_symbols) {}
template <typename TAccessor>
void ExpandUniquenessFilter<TAccessor>::Accept(LogicalOperatorVisitor &visitor) {
void ExpandUniquenessFilter<TAccessor>::Accept(
LogicalOperatorVisitor &visitor) {
visitor.Visit(*this);
input_->Accept(visitor);
visitor.PostVisit(*this);
@ -812,7 +812,6 @@ ExpandUniquenessFilter<TAccessor>::ExpandUniquenessFilterCursor::
template <typename TAccessor>
bool ExpandUniquenessFilter<TAccessor>::ExpandUniquenessFilterCursor::Pull(
Frame &frame, SymbolTable &symbol_table) {
auto expansion_ok = [&]() {
TypedValue &expand_value = frame[self_.expand_symbol_];
TAccessor &expand_accessor = expand_value.Value<TAccessor>();
@ -825,8 +824,7 @@ bool ExpandUniquenessFilter<TAccessor>::ExpandUniquenessFilterCursor::Pull(
};
while (input_cursor_->Pull(frame, symbol_table))
if (expansion_ok())
return true;
if (expansion_ok()) return true;
return false;
}
@ -835,8 +833,38 @@ bool ExpandUniquenessFilter<TAccessor>::ExpandUniquenessFilterCursor::Pull(
template class ExpandUniquenessFilter<VertexAccessor>;
template class ExpandUniquenessFilter<EdgeAccessor>;
Accumulate::Accumulate(std::shared_ptr<LogicalOperator> input, const std::vector<Symbol> &symbols) :
input_(input), symbols_(symbols) {}
namespace {
/**
* Helper function for recursively reconstructing all the accessors in the
* given TypedValue.
*/
void ReconstructTypedValue(TypedValue &value) {
switch (value.type()) {
case TypedValue::Type::Vertex:
value.Value<VertexAccessor>().Reconstruct();
break;
case TypedValue::Type::Edge:
value.Value<EdgeAccessor>().Reconstruct();
break;
case TypedValue::Type::List:
for (TypedValue &inner_value : value.Value<std::vector<TypedValue>>())
ReconstructTypedValue(inner_value);
break;
case TypedValue::Type::Map:
for (auto &kv : value.Value<std::map<std::string, TypedValue>>())
ReconstructTypedValue(kv.second);
break;
default:
break;
// TODO implement path reconstruct?
}
}
}
Accumulate::Accumulate(std::shared_ptr<LogicalOperator> input,
const std::vector<Symbol> &symbols, bool advance_command)
: input_(input), symbols_(symbols), advance_command_(advance_command) {}
void Accumulate::Accept(LogicalOperatorVisitor &visitor) {
visitor.Visit(*this);
@ -844,18 +872,37 @@ void Accumulate::Accept(LogicalOperatorVisitor &visitor) {
visitor.PostVisit(*this);
}
std::unique_ptr<Cursor> Accumulate::MakeCursor(GraphDbAccessor &db) {
return std::unique_ptr<Cursor>();
return std::make_unique<Accumulate::AccumulateCursor>(*this, db);
}
AdvanceCommand::AdvanceCommand(std::shared_ptr<LogicalOperator> input) : input_(input) {}
Accumulate::AccumulateCursor::AccumulateCursor(Accumulate &self,
GraphDbAccessor &db)
: self_(self), db_(db), input_cursor_(self.input_->MakeCursor(db)) {}
void AdvanceCommand::Accept(LogicalOperatorVisitor &visitor) {
visitor.Visit(*this);
input_->Accept(visitor);
visitor.PostVisit(*this);
}
std::unique_ptr<Cursor> AdvanceCommand::MakeCursor(GraphDbAccessor &db) {
return std::unique_ptr<Cursor>();
bool Accumulate::AccumulateCursor::Pull(Frame &frame,
SymbolTable &symbol_table) {
// cache all the input
if (!pulled_all_input_) {
while (input_cursor_->Pull(frame, symbol_table)) {
cache_.emplace_back();
auto &row = cache_.back();
for (const Symbol &symbol : self_.symbols_)
row.emplace_back(frame[symbol]);
}
pulled_all_input_ = true;
cache_it_ = cache_.begin();
if (self_.advance_command_) {
db_.advance_command();
for (auto &row : cache_)
for (auto &col : row) ReconstructTypedValue(col);
}
}
if (cache_it_ == cache_.end()) return false;
auto row_it = (cache_it_++)->begin();
for (const Symbol &symbol : self_.symbols_) frame[symbol] = *row_it++;
return true;
}
} // namespace plan

View File

@ -742,25 +742,57 @@ class ExpandUniquenessFilter : public LogicalOperator {
};
};
/** @brief Pulls everything from the input before passing it through.
* Optionally advances the command after accumulation and before emitting.
*
* On the first Pull from this Op's Cursor the input Cursor will be
* Pulled until it is empty. The results will be accumulated in the
* temporary cache. Once the input Cursor is empty, this Op's Cursor
* will start returning cached stuff from it's Pull.
*
* This technique is used for ensuring all the operations from the
* previous LogicalOp have been performed before exposing data
* to the next. A typical use-case is the `MATCH - SET - RETURN`
* query in which every SET iteration must be performed before
* RETURN starts iterating (see Memgraph Wiki for detailed reasoning).
*
* IMPORTANT: This Op does not cache all the results but only those
* elements from the frame whose symbols (frame positions) it was given.
* All other frame positions will contain undefined junk after this
* op has executed, and should not be used.
*
* This op can also advance the command after the accumulation and
* before emitting. If the command gets advanced, every value that
* has been cached will be reconstructed before Pull returns.
*
* @param input Input @c LogicalOperator.
* @param symbols A vector of Symbols that need to be accumulated
* and exposed to the next op.
*/
class Accumulate : public LogicalOperator {
public:
Accumulate(std::shared_ptr<LogicalOperator> input, const std::vector<Symbol> &symbols);
Accumulate(std::shared_ptr<LogicalOperator> input, const std::vector<Symbol> &symbols,
bool advance_command=false);
void Accept(LogicalOperatorVisitor &visitor) override;
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
private:
std::shared_ptr<LogicalOperator> input_;
const std::vector<Symbol> symbols_;
};
bool advance_command_{false};
class AdvanceCommand : public LogicalOperator {
public:
AdvanceCommand(std::shared_ptr<LogicalOperator> input);
void Accept(LogicalOperatorVisitor &visitor) override;
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
private:
std::shared_ptr<LogicalOperator> input_;
class AccumulateCursor : public Cursor {
public:
AccumulateCursor(Accumulate &self, GraphDbAccessor &db);
bool Pull(Frame &frame, SymbolTable &symbol_table) override;
private:
Accumulate &self_;
GraphDbAccessor &db_;
std::unique_ptr<Cursor> input_cursor_;
std::list<std::list<TypedValue>> cache_;
decltype(cache_.begin()) cache_it_ = cache_.begin();
bool pulled_all_input_{false};
};
};
} // namespace plan

View File

@ -1265,3 +1265,88 @@ TEST(Interpreter, ExpandUniquenessFilter) {
EXPECT_EQ(0, check_expand_results(true, false));
EXPECT_EQ(1, check_expand_results(false, true));
}
TEST(Interpreter, Accumulate) {
// simulate the following two query execution on an empty db
// CREATE ({x:0})-[:T]->({x:0})
// MATCH (n)--(m) SET n.x = n.x + 1, m.x = m.x + 1 RETURN n.x, m.x
// without accumulation we expected results to be [[1, 1], [2, 2]]
// with accumulation we expect them to be [[2, 2], [2, 2]]
auto check = [&](bool accumulate) {
Dbms dbms;
auto dba = dbms.active();
auto prop = dba->property("x");
auto v1 = dba->insert_vertex();
v1.PropsSet(prop, 0);
auto v2 = dba->insert_vertex();
v2.PropsSet(prop, 0);
dba->insert_edge(v1, v2, dba->edge_type("T"));
dba->advance_command();
AstTreeStorage storage;
SymbolTable symbol_table;
auto n = MakeScanAll(storage, symbol_table, "n");
auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r",
EdgeAtom::Direction::BOTH, false, "m", false);
auto one = LITERAL(1);
auto n_p = PROPERTY_LOOKUP("n", prop);
symbol_table[*n_p->expression_] = n.sym_;
auto set_n_p =
std::make_shared<plan::SetProperty>(r_m.op_, n_p, ADD(n_p, one));
auto m_p = PROPERTY_LOOKUP("m", prop);
symbol_table[*m_p->expression_] = r_m.node_sym_;
auto set_m_p =
std::make_shared<plan::SetProperty>(set_n_p, m_p, ADD(m_p, one));
std::shared_ptr<LogicalOperator> last_op = set_m_p;
if (accumulate) {
last_op = std::make_shared<Accumulate>(
last_op, std::vector<Symbol>{n.sym_, r_m.node_sym_});
}
auto n_p_ne = NEXPR("n.p", n_p);
symbol_table[*n_p_ne] = symbol_table.CreateSymbol("n_p_ne");
auto m_p_ne = NEXPR("m.p", m_p);
symbol_table[*m_p_ne] = symbol_table.CreateSymbol("m_p_ne");
auto produce = MakeProduce(last_op, n_p_ne, m_p_ne);
ResultStreamFaker results = CollectProduce(produce, symbol_table, *dba);
std::vector<int> results_data;
for (const auto &row : results.GetResults())
for (const auto &column : row)
results_data.emplace_back(column.Value<int64_t>());
if (accumulate)
EXPECT_THAT(results_data, testing::ElementsAre(2, 2, 2, 2));
else
EXPECT_THAT(results_data, testing::ElementsAre(1, 1, 2, 2));
};
check(false);
check(true);
}
TEST(Interpreter, AccumulateAdvance) {
// we simulate 'CREATE (n) WITH n AS n MATCH (m) RETURN m'
// to get correct results we need to advance the command
auto check = [&](bool advance) {
Dbms dbms;
auto dba = dbms.active();
AstTreeStorage storage;
SymbolTable symbol_table;
auto node = NODE("n");
auto sym_n = symbol_table.CreateSymbol("n");
symbol_table[*node->identifier_] = sym_n;
auto create = std::make_shared<CreateNode>(node, nullptr);
auto accumulate = std::make_shared<Accumulate>(
create, std::vector<Symbol>{sym_n}, advance);
auto match = MakeScanAll(storage, symbol_table, "m", accumulate);
EXPECT_EQ(advance ? 1 : 0, PullAll(match.op_, *dba, symbol_table));
};
check(false);
check(true);
}