Query::Plan::Aggregate

Summary:
- Aggregation LogicalOperation added, with tests.

- Added capabilities to TypedValue (hash, bool-equality)
to support std::unordered_map<TypedValue>.

- Removed some bad code from utils/hashing/fnv and added
a hashing function for collections.

Reviewers: buda, mislav.bradac, teon.banek

Reviewed By: teon.banek

Subscribers: lion, pullbot

Differential Revision: https://phabricator.memgraph.io/D252
This commit is contained in:
florijan 2017-04-11 15:11:48 +02:00
parent dfa6800edd
commit 593e4e72b9
15 changed files with 1437 additions and 750 deletions

View File

@ -146,7 +146,7 @@ class ExpressionEvaluator : public TreeVisitorBase {
}
void Visit(Aggregation &aggregation) override {
auto value = frame_[symbol_table_[aggregation]];
auto value = frame_[symbol_table_.at(aggregation)];
// Aggregation is probably always simple type, but let's switch accessor
// just to be sure.
SwitchAccessors(value);

View File

@ -1,3 +1,5 @@
#include <algorithm>
#include "query/frontend/logical/operator.hpp"
#include "query/exceptions.hpp"
@ -285,7 +287,6 @@ bool Expand::ExpandCursor::InitEdges(Frame &frame,
// will need it). For now only Back expansion (left to right) is
// supported
// TODO add support for named paths
// TODO add support for uniqueness (edge, vertex)
return true;
}
@ -892,9 +893,11 @@ void ReconstructTypedValue(TypedValue &value) {
for (auto &kv : value.Value<std::map<std::string, TypedValue>>())
ReconstructTypedValue(kv.second);
break;
case TypedValue::Type::Path:
// TODO implement path reconstruct?
throw NotYetImplemented("Path reconstruction not yet supported");
default:
break;
// TODO implement path reconstruct?
}
}
}
@ -946,8 +949,13 @@ bool Accumulate::AccumulateCursor::Pull(Frame &frame,
Aggregate::Aggregate(const std::shared_ptr<LogicalOperator> &input,
const std::vector<Aggregate::Element> &aggregations,
const std::vector<NamedExpression *> group_by)
: input_(input), aggregations_(aggregations), group_by_(group_by) {}
const std::vector<Expression *> &group_by,
const std::vector<Symbol> &remember, bool advance_command)
: input_(input),
aggregations_(aggregations),
group_by_(group_by),
remember_(remember),
advance_command_(advance_command) {}
void Aggregate::Accept(LogicalOperatorVisitor &visitor) {
if (visitor.PreVisit(*this)) {
@ -958,7 +966,182 @@ void Aggregate::Accept(LogicalOperatorVisitor &visitor) {
}
std::unique_ptr<Cursor> Aggregate::MakeCursor(GraphDbAccessor &db) {
return std::unique_ptr<Cursor>();
return std::make_unique<AggregateCursor>(*this, db);
}
Aggregate::AggregateCursor::AggregateCursor(Aggregate &self,
GraphDbAccessor &db)
: self_(self), db_(db), input_cursor_(self_.input_->MakeCursor(db)) {}
bool Aggregate::AggregateCursor::Pull(Frame &frame,
const SymbolTable &symbol_table) {
if (!pulled_all_input_) {
PullAllInput(frame, symbol_table);
pulled_all_input_ = true;
aggregation_it_ = aggregation_.begin();
if (self_.advance_command_) {
db_.advance_command();
// regarding reconstruction after advance_command
// we have to reconstruct only the remember values
// because aggregation results are primitives and
// group-by elements won't be used directly (possibly re-evaluated
// using remember values)
for (auto &kv : aggregation_)
for (TypedValue &remember : kv.second.remember_)
ReconstructTypedValue(remember);
}
}
if (aggregation_it_ == aggregation_.end()) return false;
// place aggregation values on the frame
auto aggregation_values_it = aggregation_it_->second.values_.begin();
for (const auto &aggregation_elem : self_.aggregations_)
frame[std::get<2>(aggregation_elem)] = *aggregation_values_it++;
// place remember values on the frame
auto remember_values_it = aggregation_it_->second.remember_.begin();
for (const Symbol &remember_sym : self_.remember_)
frame[remember_sym] = *remember_values_it++;
aggregation_it_++;
return true;
}
void Aggregate::AggregateCursor::PullAllInput(Frame &frame,
const SymbolTable &symbol_table) {
ExpressionEvaluator evaluator(frame, symbol_table);
evaluator.SwitchNew();
while (input_cursor_->Pull(frame, symbol_table)) {
// create the group-by list of values
std::list<TypedValue> group_by;
for (Expression *expression : self_.group_by_) {
expression->Accept(evaluator);
group_by.emplace_back(evaluator.PopBack());
}
AggregationValue &agg_value = aggregation_[group_by];
EnsureInitialized(frame, agg_value);
Update(frame, symbol_table, evaluator, agg_value);
}
// calculate AVG aggregations (so far they have only been summed)
for (int pos = 0; pos < self_.aggregations_.size(); ++pos) {
if (std::get<1>(self_.aggregations_[pos]) != Aggregation::Op::AVG) continue;
for (auto &kv : aggregation_) {
AggregationValue &agg_value = kv.second;
int count = agg_value.counts_[pos];
if (count > 0)
agg_value.values_[pos] = agg_value.values_[pos] / (double)count;
}
}
}
void Aggregate::AggregateCursor::EnsureInitialized(
Frame &frame, Aggregate::AggregateCursor::AggregationValue &agg_value) {
if (agg_value.values_.size() > 0) return;
agg_value.values_.resize(self_.aggregations_.size(), TypedValue::Null);
agg_value.counts_.resize(self_.aggregations_.size(), 0);
for (const Symbol &remember_sym : self_.remember_)
agg_value.remember_.push_back(frame[remember_sym]);
}
void Aggregate::AggregateCursor::Update(
Frame &frame, const SymbolTable &symbol_table,
ExpressionEvaluator &evaluator,
Aggregate::AggregateCursor::AggregationValue &agg_value) {
debug_assert(self_.aggregations_.size() == agg_value.values_.size(),
"Inappropriate AggregationValue.values_ size");
debug_assert(self_.aggregations_.size() == agg_value.counts_.size(),
"Inappropriate AggregationValue.counts_ size");
// we iterate over counts, values and aggregation info at the same time
auto count_it = agg_value.counts_.begin();
auto value_it = agg_value.values_.begin();
auto agg_elem_it = self_.aggregations_.begin();
for (; count_it < agg_value.counts_.end();
count_it++, value_it++, agg_elem_it++) {
std::get<0>(*agg_elem_it)->Accept(evaluator);
TypedValue input_value = evaluator.PopBack();
if (input_value.type() == TypedValue::Type::Null) continue;
*count_it += 1;
if (*count_it == 1) {
// first value, nothing to aggregate. set and continue.
*value_it = input_value;
continue;
}
// aggregation of existing values
switch (std::get<1>(*agg_elem_it)) {
case Aggregation::Op::COUNT:
*value_it = *count_it;
break;
case Aggregation::Op::MIN: {
EnsureOkForMinMax(input_value);
// TODO an illegal comparison here will throw a TypedValueException
// consider catching and throwing something else
TypedValue comparison_result = input_value < *value_it;
// since we skip nulls we either have a valid comparison, or
// an exception was just thrown above
// safe to assume a bool TypedValue
if (comparison_result.Value<bool>()) *value_it = input_value;
break;
}
case Aggregation::Op::MAX: {
// all comments as for Op::Min
EnsureOkForMinMax(input_value);
TypedValue comparison_result = input_value > *value_it;
if (comparison_result.Value<bool>()) *value_it = input_value;
break;
}
case Aggregation::Op::AVG:
// for averaging we sum first and divide by count once all
// the input has been processed
case Aggregation::Op::SUM:
EnsureOkForAvgSum(input_value);
*value_it = *value_it + input_value;
break;
} // end switch over Aggregation::Op enum
} // end loop over all aggregations
}
void Aggregate::AggregateCursor::EnsureOkForMinMax(const TypedValue &value) {
switch (value.type()) {
case TypedValue::Type::Bool:
case TypedValue::Type::Int:
case TypedValue::Type::Double:
case TypedValue::Type::String:
return;
default:
// TODO consider better error feedback
throw TypedValueException(
"Only Bool, Int, Double and String properties are allowed in "
"MIN and MAX aggregations");
}
}
void Aggregate::AggregateCursor::EnsureOkForAvgSum(const TypedValue &value) {
switch (value.type()) {
case TypedValue::Type::Int:
case TypedValue::Type::Double:
return;
default:
// TODO consider better error feedback
throw TypedValueException(
"Only numeric properties allowed in SUM and AVG aggregations");
}
}
bool Aggregate::AggregateCursor::TypedValueListEqual::operator()(
const std::list<TypedValue> &left,
const std::list<TypedValue> &right) const {
return std::equal(left.begin(), left.end(), right.begin(),
TypedValue::BoolEqual{});
}
} // namespace plan

View File

@ -2,12 +2,16 @@
#pragma once
#include <algorithm>
#include <memory>
#include <tuple>
#include <unordered_map>
#include <vector>
#include "database/graph_db_accessor.hpp"
#include "database/graph_db_datatypes.hpp"
#include "query/frontend/semantic/symbol_table.hpp"
#include "utils/hashing/fnv.hpp"
#include "utils/visitor/visitable.hpp"
#include "utils/visitor/visitor.hpp"
@ -824,14 +828,94 @@ class Aggregate : public LogicalOperator {
Aggregate(const std::shared_ptr<LogicalOperator> &input,
const std::vector<Element> &aggregations,
const std::vector<NamedExpression *> group_by);
const std::vector<Expression *> &group_by,
const std::vector<Symbol> &remember, bool advance_command = false);
void Accept(LogicalOperatorVisitor &visitor) override;
std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor &db) override;
private:
const std::shared_ptr<LogicalOperator> input_;
const std::vector<Element> aggregations_;
const std::vector<NamedExpression *> group_by_;
const std::vector<Expression *> group_by_;
const std::vector<Symbol> remember_;
const bool advance_command_;
class AggregateCursor : public Cursor {
public:
AggregateCursor(Aggregate &self, GraphDbAccessor &db);
bool Pull(Frame &frame, const SymbolTable &symbol_table) override;
private:
// custom equality function for the unordered map
struct TypedValueListEqual {
bool operator()(const std::list<TypedValue> &left,
const std::list<TypedValue> &right) const;
};
// Data structure for a single aggregation cache.
// does NOT include the group-by values since those
// are a key in the aggregation map.
// The vectors in an AggregationValue contain one element
// for each aggregation in this LogicalOp.
struct AggregationValue {
// how many input rows has been aggregated in respective
// values_ element so far
std::vector<int> counts_;
// aggregated values. Initially Null (until at least one
// input row with a valid value gets processed)
std::vector<TypedValue> values_;
// remember values.
std::vector<TypedValue> remember_;
};
Aggregate &self_;
GraphDbAccessor &db_;
std::unique_ptr<Cursor> input_cursor_;
// storage for aggregated data
// map key is the list of group-by values
// map value is an AggregationValue struct
std::unordered_map<
std::list<TypedValue>, AggregationValue,
// use FNV collection hashing specialized for a list of TypedValues
FnvCollection<std::list<TypedValue>, TypedValue, TypedValue::Hash>,
// custom equality
TypedValueListEqual>
aggregation_;
// iterator over the accumulated cache
decltype(aggregation_.begin()) aggregation_it_ = aggregation_.begin();
// this LogicalOp pulls all from the input on it's first pull
// this switch tracks if this has been performed
bool pulled_all_input_{false};
/**
* Pulls from the input until it's exhausted. Accumulates the results
* in the `aggregation_` map.
*
* Accumulation automatically groups the results so that `aggregation_`
* cache cardinality depends on number of
* aggregation results, and not on the number of inputs.
*/
void PullAllInput(Frame &frame, const SymbolTable &symbolTable);
/** Ensures the new AggregationValue has been initialized. This means
* that the value vectors are filled with an appropriate number of Nulls,
* counts are set to 0 and remember values are remembered.
*/
void EnsureInitialized(Frame &frame, AggregationValue &agg_value);
/** Updates the given AggregationValue with new data. Assumes that
* the AggregationValue has been initialized */
void Update(Frame &frame, const SymbolTable &symbol_table,
ExpressionEvaluator &evaluator, AggregationValue &agg_value);
/** Checks if the given TypedValue is legal in MIN and MAX. If not
* an appropriate exception is thrown. */
void EnsureOkForMinMax(const TypedValue &value);
/** Checks if the given TypedValue is legal in AVG and SUM. If not
* an appropriate exception is thrown. */
void EnsureOkForAvgSum(const TypedValue &value);
};
};
} // namespace plan

View File

@ -6,6 +6,8 @@
#include <memory>
#include "utils/assert.hpp"
#include "utils/exceptions/not_yet_implemented.hpp"
#include "utils/hashing/fnv.hpp"
namespace query {
@ -441,6 +443,10 @@ TypedValue operator==(const TypedValue &a, const TypedValue &b) {
if (a.type() == TypedValue::Type::Null || b.type() == TypedValue::Type::Null)
return TypedValue::Null;
if (a.type() == TypedValue::Type::Map || b.type() == TypedValue::Type::Map) {
throw NotYetImplemented();
}
if (a.type() == TypedValue::Type::List ||
b.type() == TypedValue::Type::List) {
if (a.type() == TypedValue::Type::List &&
@ -731,4 +737,66 @@ TypedValue operator^(const TypedValue &a, const TypedValue &b) {
}
}
bool TypedValue::BoolEqual::operator()(const TypedValue &lhs,
const TypedValue &rhs) const {
if (lhs.type() == TypedValue::Type::Null &&
rhs.type() == TypedValue::Type::Null)
return true;
// legal comparisons are only between same types
// only int -> float is promoted
if (lhs.type() == rhs.type() ||
(lhs.type() == Type::Double && rhs.type() == Type::Int) ||
(rhs.type() == Type::Double && lhs.type() == Type::Int))
{
TypedValue equality_result = lhs == rhs;
switch (equality_result.type()) {
case TypedValue::Type::Bool:
return equality_result.Value<bool>();
case TypedValue::Type::Null:
// we already tested if both operands are null,
// so only one is null here. this evaluates to false equality
return false;
default:
permanent_fail(
"Equality between two TypedValues resulted in something other "
"then Null or bool");
}
}
return false;
}
size_t TypedValue::Hash::operator()(const TypedValue &value) const {
switch (value.type()) {
case TypedValue::Type::Null:
return 31;
case TypedValue::Type::Bool:
return std::hash<bool>{}(value.Value<bool>());
case TypedValue::Type::Int:
// we cast int to double for hashing purposes
// to be consistent with TypedValue equality
// in which (2.0 == 2) returns true
return std::hash<double>{}((double)value.Value<int64_t>());
case TypedValue::Type::Double:
return std::hash<double>{}(value.Value<double>());
case TypedValue::Type::String:
return std::hash<std::string>{}(value.Value<std::string>());
case TypedValue::Type::List: {
return FnvCollection<std::vector<TypedValue>, TypedValue, Hash>{}(
value.Value<std::vector<TypedValue>>());
}
case TypedValue::Type::Map:
throw NotYetImplemented();
case TypedValue::Type::Vertex:
return value.Value<VertexAccessor>().temporary_id();
case TypedValue::Type::Edge:
return value.Value<EdgeAccessor>().temporary_id();
case TypedValue::Type::Path:
throw NotYetImplemented();
break;
}
permanent_fail("Unhandled TypedValue.type() in hash function");
}
} // namespace query

View File

@ -5,6 +5,7 @@
#include <map>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "storage/edge_accessor.hpp"
@ -29,6 +30,34 @@ typedef traversal_template::Path<VertexAccessor, EdgeAccessor> Path;
*/
class TypedValue : public TotalOrdering<TypedValue, TypedValue, TypedValue> {
public:
/** Custom TypedValue equality function that returns a bool
* (as opposed to returning TypedValue as the default equality does).
* This implementation treats two nulls as being equal and null
* not being equal to everything else.
*/
struct BoolEqual {
bool operator()(const TypedValue &left, const TypedValue &right) const;
};
/** Hash operator for TypedValue.
*
* Not injecting into std
* due to linking problems. If the implementation is in this header,
* then it implicitly instantiates TypedValue::Value<T> before
* explicit instantiation in .cpp file. If the implementation is in
* the .cpp file, it won't link.
*/
struct Hash {
size_t operator()(const TypedValue &value) const;
};
/**
* Unordered set of TypedValue items. Can contain at most one Null element,
* and treats an integral and floating point value as same if they are equal
* in the floating point domain (TypedValue operator== behaves the same).
* */
using unordered_set = std::unordered_set<TypedValue, Hash, BoolEqual>;
/** Private default constructor, makes Null */
TypedValue() : type_(Type::Null) {}
@ -163,7 +192,6 @@ TypedValue operator||(const TypedValue &a, const TypedValue &b);
// Be careful: since ^ is binary operator and || and && are logical operators
// they have different priority in c++.
TypedValue operator^(const TypedValue &a, const TypedValue &b);
// stream output
std::ostream &operator<<(std::ostream &os, const TypedValue::Type type);

View File

@ -15,21 +15,53 @@ namespace {
#ifdef MEMGRAPH64
template <class T>
uint64_t fnv(const T& data) {
return fnv1a64<T>(data);
__attribute__((unused)) uint64_t fnv(const std::string &s) {
return fnv1a64(s);
}
using HashType = uint64_t;
#elif
template <class T>
uint32_t fnv(const T& data) {
return fnv1a32<T>(data);
__attribute__((unused)) uint32_t fnv(const std::string &s) {
return fnv1a32(s);
}
using HashType = uint32_t;
#endif
}
/**
* Does FNV-like hashing on a collection. Not truly FNV
* because it operates on 8-bit elements, while this
* implementation uses size_t elements (collection item
* hash).
*
* https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function
*
*
* @tparam TIterable A collection type that has begin() and end().
* @tparam TElement Type of element in the collection.
* @tparam THash Hash type (has operator() that accepts a 'const TEelement &'
* and returns size_t. Defaults to std::hash<TElement>.
* @param iterable A collection of elements.
* @param element_hash Function for hashing a single element.
* @return The hash of the whole collection.
*/
template <typename TIterable, typename TElement,
typename THash = std::hash<TElement>>
struct FnvCollection {
size_t operator()(const TIterable &iterable) const {
uint64_t hash = 14695981039346656037u;
THash element_hash;
for (const TElement &element : iterable) {
hash *= fnv_prime;
hash ^= element_hash(element);
}
return hash;
}
private:
static const uint64_t fnv_prime = 1099511628211u;
};

View File

@ -8,41 +8,21 @@ namespace {
#define OFFSET_BASIS32 2166136261u
#define FNV_PRIME32 16777619u
uint32_t fnv32(const unsigned char* const data, size_t n) {
__attribute__((unused)) uint32_t fnv32(const std::string &s) {
uint32_t hash = OFFSET_BASIS32;
for (size_t i = 0; i < n; ++i)
hash = (hash * FNV_PRIME32) xor (uint32_t) data[i];
for (size_t i = 0; i < s.size(); ++i)
hash = (hash * FNV_PRIME32) xor (uint32_t) s[i];
return hash;
}
template <class T>
uint32_t fnv32(const T& data) {
return fnv32(&data, sizeof(data));
}
template <>
__attribute__((unused)) uint32_t fnv32(const std::string& data) {
return fnv32((const unsigned char*)data.c_str(), data.size());
}
uint32_t fnv1a32(const unsigned char* const data, size_t n) {
__attribute__((unused)) uint32_t fnv1a32(const std::string &s) {
uint32_t hash = OFFSET_BASIS32;
for (size_t i = 0; i < n; ++i)
hash = (hash xor (uint32_t) data[i]) * FNV_PRIME32;
for (size_t i = 0; i < s.size(); ++i)
hash = (hash xor (uint32_t) s[i]) * FNV_PRIME32;
return hash;
}
template <class T>
uint32_t fnv1a32(const T& data) {
return fnv1a32(&data, sizeof(data));
}
template <>
__attribute__((unused)) uint32_t fnv1a32(const std::string& data) {
return fnv1a32((const unsigned char*)data.c_str(), data.size());
}
}

View File

@ -8,41 +8,21 @@ namespace {
#define OFFSET_BASIS64 14695981039346656037u
#define FNV_PRIME64 1099511628211u
uint64_t fnv64(const unsigned char* const data, size_t n) {
__attribute__((unused)) uint64_t fnv64(const std::string &s) {
uint64_t hash = OFFSET_BASIS64;
for (size_t i = 0; i < n; ++i)
hash = (hash * FNV_PRIME64) xor (uint64_t) data[i];
for (size_t i = 0; i < s.size(); ++i)
hash = (hash * FNV_PRIME64) xor (uint64_t) s[i];
return hash;
}
template <class T>
uint64_t fnv64(const T& data) {
return fnv64(&data, sizeof(data));
}
template <>
__attribute__((unused)) uint64_t fnv64(const std::string& data) {
return fnv64((const unsigned char*)data.c_str(), data.size());
}
uint64_t fnv1a64(const unsigned char* const data, size_t n) {
__attribute__((unused)) uint64_t fnv1a64(const std::string &s) {
uint64_t hash = OFFSET_BASIS64;
for (size_t i = 0; i < n; ++i)
hash = (hash xor (uint64_t) data[i]) * FNV_PRIME64;
for (size_t i = 0; i < s.size(); ++i)
hash = (hash xor (uint64_t) s[i]) * FNV_PRIME64;
return hash;
}
template <class T>
uint64_t fnv1a64(const T& data) {
return fnv1a64(&data, sizeof(data));
}
template <>
__attribute__((unused)) uint64_t fnv1a64(const std::string& data) {
return fnv1a64((const unsigned char*)data.c_str(), data.size());
}
}

View File

@ -38,8 +38,8 @@ int main(int argc, char **argv) {
auto elements = utils::random::generate_vector(generator, 1 << 16);
StringHashFunction hash1 = fnv64<std::string>;
StringHashFunction hash2 = fnv1a64<std::string>;
StringHashFunction hash1 = fnv64;
StringHashFunction hash2 = fnv1a64;
std::vector<StringHashFunction> funcs = {hash1, hash2};
BloomFilter<std::string, 128> bloom(funcs);

View File

@ -129,8 +129,8 @@ int main(int argc, char **argv) {
PairGenerator<StringGenerator, IntegerGenerator> psig(&sg, &ig);
PairGenerator<IntegerGenerator, StringGenerator> pisg(&ig, &sg);
StringHashFunction hash1 = fnv64<std::string>;
StringHashFunction hash2 = fnv1a64<std::string>;
StringHashFunction hash1 = fnv64;
StringHashFunction hash2 = fnv1a64;
std::vector<StringHashFunction> funcs = {hash1, hash2};
BloomFilter<std::string, 128> bloom_filter_(funcs);

View File

@ -9,8 +9,8 @@
using StringHashFunction = std::function<uint64_t(const std::string &)>;
TEST(BloomFilterTest, InsertContains) {
StringHashFunction hash1 = fnv64<std::string>;
StringHashFunction hash2 = fnv1a64<std::string>;
StringHashFunction hash1 = fnv64;
StringHashFunction hash2 = fnv1a64;
std::vector<StringHashFunction> funcs = {hash1, hash2};
BloomFilter<std::string, 64> bloom(funcs);

View File

@ -0,0 +1,382 @@
//
// Copyright 2017 Memgraph
// Created by Florijan Stamenkovic on 14.03.17.
//
#include <algorithm>
#include <iterator>
#include <memory>
#include <vector>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "communication/result_stream_faker.hpp"
#include "dbms/dbms.hpp"
#include "query/context.hpp"
#include "query/exceptions.hpp"
#include "query/frontend/interpret/interpret.hpp"
#include "query/frontend/logical/operator.hpp"
#include "query_plan_common.hpp"
using namespace query;
using namespace query::plan;
TEST(QueryPlan, Accumulate) {
// simulate the following two query execution on an empty db
// CREATE ({x:0})-[:T]->({x:0})
// MATCH (n)--(m) SET n.x = n.x + 1, m.x = m.x + 1 RETURN n.x, m.x
// without accumulation we expected results to be [[1, 1], [2, 2]]
// with accumulation we expect them to be [[2, 2], [2, 2]]
auto check = [&](bool accumulate) {
Dbms dbms;
auto dba = dbms.active();
auto prop = dba->property("x");
auto v1 = dba->insert_vertex();
v1.PropsSet(prop, 0);
auto v2 = dba->insert_vertex();
v2.PropsSet(prop, 0);
dba->insert_edge(v1, v2, dba->edge_type("T"));
dba->advance_command();
AstTreeStorage storage;
SymbolTable symbol_table;
auto n = MakeScanAll(storage, symbol_table, "n");
auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r",
EdgeAtom::Direction::BOTH, false, "m", false);
auto one = LITERAL(1);
auto n_p = PROPERTY_LOOKUP("n", prop);
symbol_table[*n_p->expression_] = n.sym_;
auto set_n_p =
std::make_shared<plan::SetProperty>(r_m.op_, n_p, ADD(n_p, one));
auto m_p = PROPERTY_LOOKUP("m", prop);
symbol_table[*m_p->expression_] = r_m.node_sym_;
auto set_m_p =
std::make_shared<plan::SetProperty>(set_n_p, m_p, ADD(m_p, one));
std::shared_ptr<LogicalOperator> last_op = set_m_p;
if (accumulate) {
last_op = std::make_shared<Accumulate>(
last_op, std::vector<Symbol>{n.sym_, r_m.node_sym_});
}
auto n_p_ne = NEXPR("n.p", n_p);
symbol_table[*n_p_ne] = symbol_table.CreateSymbol("n_p_ne");
auto m_p_ne = NEXPR("m.p", m_p);
symbol_table[*m_p_ne] = symbol_table.CreateSymbol("m_p_ne");
auto produce = MakeProduce(last_op, n_p_ne, m_p_ne);
ResultStreamFaker results = CollectProduce(produce, symbol_table, *dba);
std::vector<int> results_data;
for (const auto &row : results.GetResults())
for (const auto &column : row)
results_data.emplace_back(column.Value<int64_t>());
if (accumulate)
EXPECT_THAT(results_data, testing::ElementsAre(2, 2, 2, 2));
else
EXPECT_THAT(results_data, testing::ElementsAre(1, 1, 2, 2));
};
check(false);
check(true);
}
TEST(QueryPlan, AccumulateAdvance) {
// we simulate 'CREATE (n) WITH n AS n MATCH (m) RETURN m'
// to get correct results we need to advance the command
auto check = [&](bool advance) {
Dbms dbms;
auto dba = dbms.active();
AstTreeStorage storage;
SymbolTable symbol_table;
auto node = NODE("n");
auto sym_n = symbol_table.CreateSymbol("n");
symbol_table[*node->identifier_] = sym_n;
auto create = std::make_shared<CreateNode>(node, nullptr);
auto accumulate = std::make_shared<Accumulate>(
create, std::vector<Symbol>{sym_n}, advance);
auto match = MakeScanAll(storage, symbol_table, "m", accumulate);
EXPECT_EQ(advance ? 1 : 0, PullAll(match.op_, *dba, symbol_table));
};
check(false);
check(true);
}
std::shared_ptr<Produce> MakeAggregationProduce(
std::shared_ptr<LogicalOperator> input, SymbolTable &symbol_table,
AstTreeStorage &storage, const std::vector<Expression *> aggr_inputs,
const std::vector<Aggregation::Op> aggr_ops,
const std::vector<Expression *> group_by_exprs,
const std::vector<Symbol> remember) {
permanent_assert(aggr_inputs.size() == aggr_ops.size(),
"Provide as many aggr inputs as aggr ops");
// prepare all the aggregations
std::vector<Aggregate::Element> aggregates;
std::vector<NamedExpression *> named_expressions;
auto aggr_inputs_it = aggr_inputs.begin();
for (auto aggr_op : aggr_ops) {
// TODO change this from using IDENT to using AGGREGATION
// once AGGREGATION is handled properly in ExpressionEvaluation
auto named_expr = NEXPR("", IDENT("aggregation"));
named_expressions.push_back(named_expr);
symbol_table[*named_expr->expression_] =
symbol_table.CreateSymbol("aggregation");
symbol_table[*named_expr] = symbol_table.CreateSymbol("named_expression");
aggregates.emplace_back(*aggr_inputs_it++, aggr_op,
symbol_table[*named_expr->expression_]);
}
// Produce will also evaluate group_by expressions
// and return them after the aggregations
for (auto group_by_expr : group_by_exprs) {
auto named_expr = NEXPR("", group_by_expr);
named_expressions.push_back(named_expr);
symbol_table[*named_expr] = symbol_table.CreateSymbol("named_expression");
}
auto aggregation =
std::make_shared<Aggregate>(input, aggregates, group_by_exprs, remember);
return std::make_shared<Produce>(aggregation, named_expressions);
}
TEST(QueryPlan, AggregateOps) {
Dbms dbms;
auto dba = dbms.active();
// setup is several nodes most of which have an int property set
// we will take the sum, avg, min, max and count
// we won't group by anything
auto prop = dba->property("prop");
dba->insert_vertex().PropsSet(prop, 4);
dba->insert_vertex().PropsSet(prop, 7);
dba->insert_vertex().PropsSet(prop, 12);
// a missing property (null) gets ignored by all aggregations
dba->insert_vertex();
dba->advance_command();
AstTreeStorage storage;
SymbolTable symbol_table;
// match all nodes and perform aggregations
auto n = MakeScanAll(storage, symbol_table, "n");
auto n_p = PROPERTY_LOOKUP("n", prop);
symbol_table[*n_p->expression_] = n.sym_;
auto produce = MakeAggregationProduce(
n.op_, symbol_table, storage, std::vector<Expression *>(5, n_p),
{Aggregation::Op::COUNT, Aggregation::Op::MIN, Aggregation::Op::MAX,
Aggregation::Op::SUM, Aggregation::Op::AVG},
{}, {});
// checks
auto results = CollectProduce(produce, symbol_table, *dba).GetResults();
ASSERT_EQ(results.size(), 1);
ASSERT_EQ(results[0].size(), 5);
// count
ASSERT_EQ(results[0][0].type(), TypedValue::Type::Int);
EXPECT_EQ(results[0][0].Value<int64_t>(), 3);
// min
ASSERT_EQ(results[0][1].type(), TypedValue::Type::Int);
EXPECT_EQ(results[0][1].Value<int64_t>(), 4);
// max
ASSERT_EQ(results[0][2].type(), TypedValue::Type::Int);
EXPECT_EQ(results[0][2].Value<int64_t>(), 12);
// sum
ASSERT_EQ(results[0][3].type(), TypedValue::Type::Int);
EXPECT_EQ(results[0][3].Value<int64_t>(), 23);
// avg
ASSERT_EQ(results[0][4].type(), TypedValue::Type::Double);
EXPECT_FLOAT_EQ(results[0][4].Value<double>(), 23 / 3.0);
}
TEST(QueryPlan, AggregateGroupByValues) {
// tests that distinct groups are aggregated properly
// for values of all types
// also test the "remember" part of the Aggregation API
// as final results are obtained via a property lookup of
// a remembered node
Dbms dbms;
auto dba = dbms.active();
// a vector of TypedValue to be set as property values on vertices
// most of them should result in a distinct group (commented where not)
std::vector<TypedValue> group_by_vals;
group_by_vals.emplace_back(4);
group_by_vals.emplace_back(7);
group_by_vals.emplace_back(7.3);
group_by_vals.emplace_back(7.2);
group_by_vals.emplace_back("Johhny");
group_by_vals.emplace_back("Jane");
group_by_vals.emplace_back("1");
group_by_vals.emplace_back(true);
group_by_vals.emplace_back(false);
group_by_vals.emplace_back(std::vector<TypedValue>{1});
group_by_vals.emplace_back(std::vector<TypedValue>{1, 2});
group_by_vals.emplace_back(std::vector<TypedValue>{2, 1});
group_by_vals.emplace_back(TypedValue::Null);
// should NOT result in another group because 7.0 == 7
group_by_vals.emplace_back(7.0);
// should NOT result in another group
group_by_vals.emplace_back(std::vector<TypedValue>{1, 2.0});
// generate a lot of vertices and set props on them
auto prop = dba->property("prop");
for (int i = 0; i < 1000; ++i)
dba->insert_vertex().PropsSet(prop,
group_by_vals[i % group_by_vals.size()]);
dba->advance_command();
AstTreeStorage storage;
SymbolTable symbol_table;
// match all nodes and perform aggregations
auto n = MakeScanAll(storage, symbol_table, "n");
auto n_p = PROPERTY_LOOKUP("n", prop);
symbol_table[*n_p->expression_] = n.sym_;
auto produce =
MakeAggregationProduce(n.op_, symbol_table, storage, {n_p},
{Aggregation::Op::COUNT}, {n_p}, {n.sym_});
auto results = CollectProduce(produce, symbol_table, *dba).GetResults();
ASSERT_EQ(results.size(), group_by_vals.size() - 2);
TypedValue::unordered_set result_group_bys;
for (const auto &row : results) {
ASSERT_EQ(2, row.size());
result_group_bys.insert(row[1]);
}
ASSERT_EQ(result_group_bys.size(), group_by_vals.size() - 2);
EXPECT_TRUE(std::is_permutation(
group_by_vals.begin(), group_by_vals.end() - 2, result_group_bys.begin(),
TypedValue::BoolEqual{}));
}
TEST(QueryPlan, AggregateMultipleGroupBy) {
// in this test we have 3 different properties that have different values
// for different records and assert that we get the correct combination
// of values in our groups
Dbms dbms;
auto dba = dbms.active();
auto prop1 = dba->property("prop1");
auto prop2 = dba->property("prop2");
auto prop3 = dba->property("prop3");
for (int i = 0; i < 2 * 3 * 5; ++i) {
auto v = dba->insert_vertex();
v.PropsSet(prop1, (bool)(i % 2));
v.PropsSet(prop2, i % 3);
v.PropsSet(prop3, "value" + std::to_string(i % 5));
}
dba->advance_command();
AstTreeStorage storage;
SymbolTable symbol_table;
// match all nodes and perform aggregations
auto n = MakeScanAll(storage, symbol_table, "n");
auto n_p1 = PROPERTY_LOOKUP("n", prop1);
auto n_p2 = PROPERTY_LOOKUP("n", prop2);
auto n_p3 = PROPERTY_LOOKUP("n", prop3);
symbol_table[*n_p1->expression_] = n.sym_;
symbol_table[*n_p2->expression_] = n.sym_;
symbol_table[*n_p3->expression_] = n.sym_;
auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, {n_p1},
{Aggregation::Op::COUNT},
{n_p1, n_p2, n_p3}, {n.sym_});
auto results = CollectProduce(produce, symbol_table, *dba).GetResults();
EXPECT_EQ(results.size(), 2 * 3 * 5);
}
TEST(QueryPlan, AggregateAdvance) {
// we simulate 'CREATE (n {x: 42}) WITH count(n.x) AS c MATCH (m) RETURN m,
// m.x, c'
// to get correct results we need to advance the command in aggregation
// since we only test aggregation, we'll simplify the logical plan and only
// check the count and not all the results
auto check = [&](bool advance) {
Dbms dbms;
auto dba = dbms.active();
AstTreeStorage storage;
SymbolTable symbol_table;
auto node = NODE("n");
auto sym_n = symbol_table.CreateSymbol("n");
symbol_table[*node->identifier_] = sym_n;
auto create = std::make_shared<CreateNode>(node, nullptr);
auto aggr_sym = symbol_table.CreateSymbol("aggr_sym");
auto n_p = PROPERTY_LOOKUP("n", dba->property("x"));
symbol_table[*n_p->expression_] = sym_n;
auto aggregate = std::make_shared<Aggregate>(
create, std::vector<Aggregate::Element>{Aggregate::Element{
n_p, Aggregation::Op::COUNT, aggr_sym}},
std::vector<Expression *>{}, std::vector<Symbol>{}, advance);
auto match = MakeScanAll(storage, symbol_table, "m", aggregate);
EXPECT_EQ(advance ? 1 : 0, PullAll(match.op_, *dba, symbol_table));
};
// check(false);
check(true);
}
TEST(QueryPlan, AggregateTypes) {
// testing exceptions that can get emitted by an aggregation
// does not check all combinations that can result in an exception
// (that logic is defined and tested by TypedValue)
Dbms dbms;
auto dba = dbms.active();
auto p1 = dba->property("p1"); // has only string props
dba->insert_vertex().PropsSet(p1, "string");
dba->insert_vertex().PropsSet(p1, "str2");
auto p2 = dba->property("p2"); // combines int and bool
dba->insert_vertex().PropsSet(p2, 42);
dba->insert_vertex().PropsSet(p2, true);
dba->advance_command();
AstTreeStorage storage;
SymbolTable symbol_table;
auto n = MakeScanAll(storage, symbol_table, "n");
auto n_p1 = PROPERTY_LOOKUP("n", p1);
symbol_table[*n_p1->expression_] = n.sym_;
auto n_p2 = PROPERTY_LOOKUP("n", p2);
symbol_table[*n_p2->expression_] = n.sym_;
auto aggregate = [&](Expression *expression, Aggregation::Op aggr_op) {
auto produce = MakeAggregationProduce(n.op_, symbol_table, storage,
{expression}, {aggr_op}, {}, {});
CollectProduce(produce, symbol_table, *dba).GetResults();
};
// everythin except for COUNT fails on a Vertex
auto n_id = n_p1->expression_;
aggregate(n_id, Aggregation::Op::COUNT);
EXPECT_THROW(aggregate(n_id, Aggregation::Op::MIN), TypedValueException);
EXPECT_THROW(aggregate(n_id, Aggregation::Op::MAX), TypedValueException);
EXPECT_THROW(aggregate(n_id, Aggregation::Op::AVG), TypedValueException);
EXPECT_THROW(aggregate(n_id, Aggregation::Op::SUM), TypedValueException);
// on strings AVG and SUM fail
aggregate(n_p1, Aggregation::Op::COUNT);
aggregate(n_p1, Aggregation::Op::MIN);
aggregate(n_p1, Aggregation::Op::MAX);
EXPECT_THROW(aggregate(n_p1, Aggregation::Op::AVG), TypedValueException);
EXPECT_THROW(aggregate(n_p1, Aggregation::Op::SUM), TypedValueException);
// combination of int and bool, everything except count fails
aggregate(n_p2, Aggregation::Op::COUNT);
EXPECT_THROW(aggregate(n_p2, Aggregation::Op::MIN), TypedValueException);
EXPECT_THROW(aggregate(n_p2, Aggregation::Op::MAX), TypedValueException);
EXPECT_THROW(aggregate(n_p2, Aggregation::Op::AVG), TypedValueException);
EXPECT_THROW(aggregate(n_p2, Aggregation::Op::SUM), TypedValueException);
}

View File

@ -0,0 +1,128 @@
//
// Copyright 2017 Memgraph
// Created by Florijan Stamenkovic on 14.03.17.
//
#include <iterator>
#include <memory>
#include <vector>
#include "query/frontend/semantic/symbol_table.hpp"
#include "query/frontend/interpret/interpret.hpp"
#include "query_common.hpp"
using namespace query;
using namespace query::plan;
/**
* Helper function that collects all the results from the given
* Produce into a ResultStreamFaker and returns that object.
*
* @param produce
* @param symbol_table
* @param db_accessor
* @return
*/
auto CollectProduce(std::shared_ptr<Produce> produce, SymbolTable &symbol_table,
GraphDbAccessor &db_accessor) {
ResultStreamFaker stream;
Frame frame(symbol_table.max_position());
// top level node in the operator tree is a produce (return)
// so stream out results
// generate header
std::vector<std::string> header;
for (auto named_expression : produce->named_expressions())
header.push_back(named_expression->name_);
stream.Header(header);
// collect the symbols from the return clause
std::vector<Symbol> symbols;
for (auto named_expression : produce->named_expressions())
symbols.emplace_back(symbol_table[*named_expression]);
// stream out results
auto cursor = produce->MakeCursor(db_accessor);
while (cursor->Pull(frame, symbol_table)) {
std::vector<TypedValue> values;
for (auto &symbol : symbols) values.emplace_back(frame[symbol]);
stream.Result(values);
}
stream.Summary({{std::string("type"), TypedValue("r")}});
return stream;
}
int PullAll(std::shared_ptr<LogicalOperator> logical_op, GraphDbAccessor &db,
SymbolTable symbol_table) {
Frame frame(symbol_table.max_position());
auto cursor = logical_op->MakeCursor(db);
int count = 0;
while (cursor->Pull(frame, symbol_table)) count++;
return count;
}
template <typename... TNamedExpressions>
auto MakeProduce(std::shared_ptr<LogicalOperator> input,
TNamedExpressions... named_expressions) {
return std::make_shared<Produce>(
input, std::vector<NamedExpression *>{named_expressions...});
}
struct ScanAllTuple {
NodeAtom *node_;
std::shared_ptr<LogicalOperator> op_;
Symbol sym_;
};
/**
* Creates and returns a tuple of stuff for a scan-all starting
* from the node with the given name.
*
* Returns (node_atom, scan_all_logical_op, symbol).
*/
ScanAllTuple MakeScanAll(AstTreeStorage &storage, SymbolTable &symbol_table,
const std::string &identifier,
std::shared_ptr<LogicalOperator> input = {nullptr}) {
auto node = NODE(identifier);
auto logical_op = std::make_shared<ScanAll>(node, input);
auto symbol = symbol_table.CreateSymbol(identifier);
symbol_table[*node->identifier_] = symbol;
// return std::make_tuple(node, logical_op, symbol);
return ScanAllTuple{node, logical_op, symbol};
}
struct ExpandTuple {
EdgeAtom *edge_;
Symbol edge_sym_;
NodeAtom *node_;
Symbol node_sym_;
std::shared_ptr<LogicalOperator> op_;
};
ExpandTuple MakeExpand(AstTreeStorage &storage, SymbolTable &symbol_table,
std::shared_ptr<LogicalOperator> input,
Symbol input_symbol, const std::string &edge_identifier,
EdgeAtom::Direction direction, bool edge_cycle,
const std::string &node_identifier, bool node_cycle) {
auto edge = EDGE(edge_identifier, direction);
auto edge_sym = symbol_table.CreateSymbol(edge_identifier);
symbol_table[*edge->identifier_] = edge_sym;
auto node = NODE(node_identifier);
auto node_sym = symbol_table.CreateSymbol(node_identifier);
symbol_table[*node->identifier_] = node_sym;
auto op = std::make_shared<Expand>(node, edge, input, input_symbol,
node_cycle, edge_cycle);
return ExpandTuple{edge, edge_sym, node, node_sym, op};
}
template <typename TIterable>
auto CountIterable(TIterable iterable) {
return std::distance(iterable.begin(), iterable.end());
}

View File

@ -15,300 +15,14 @@
#include "query/context.hpp"
#include "query/exceptions.hpp"
#include "query/frontend/interpret/interpret.hpp"
#include "query/frontend/logical/planner.hpp"
#include "query/frontend/logical/operator.hpp"
#include "query_common.hpp"
#include "query_plan_common.hpp"
using namespace query;
using namespace query::plan;
/**
* Helper function that collects all the results from the given
* Produce into a ResultStreamFaker and returns that object.
*
* @param produce
* @param symbol_table
* @param db_accessor
* @return
*/
auto CollectProduce(std::shared_ptr<Produce> produce, SymbolTable &symbol_table,
GraphDbAccessor &db_accessor) {
ResultStreamFaker stream;
Frame frame(symbol_table.max_position());
// top level node in the operator tree is a produce (return)
// so stream out results
// generate header
std::vector<std::string> header;
for (auto named_expression : produce->named_expressions())
header.push_back(named_expression->name_);
stream.Header(header);
// collect the symbols from the return clause
std::vector<Symbol> symbols;
for (auto named_expression : produce->named_expressions())
symbols.emplace_back(symbol_table[*named_expression]);
// stream out results
auto cursor = produce->MakeCursor(db_accessor);
while (cursor->Pull(frame, symbol_table)) {
std::vector<TypedValue> values;
for (auto &symbol : symbols) values.emplace_back(frame[symbol]);
stream.Result(values);
}
stream.Summary({{std::string("type"), TypedValue("r")}});
return stream;
}
int PullAll(std::shared_ptr<LogicalOperator> logical_op, GraphDbAccessor &db,
SymbolTable symbol_table) {
Frame frame(symbol_table.max_position());
auto cursor = logical_op->MakeCursor(db);
int count = 0;
while (cursor->Pull(frame, symbol_table)) count++;
return count;
}
template <typename... TNamedExpressions>
auto MakeProduce(std::shared_ptr<LogicalOperator> input,
TNamedExpressions... named_expressions) {
return std::make_shared<Produce>(
input, std::vector<NamedExpression *>{named_expressions...});
}
struct ScanAllTuple {
NodeAtom *node_;
std::shared_ptr<LogicalOperator> op_;
Symbol sym_;
};
/**
* Creates and returns a tuple of stuff for a scan-all starting
* from the node with the given name.
*
* Returns (node_atom, scan_all_logical_op, symbol).
*/
ScanAllTuple MakeScanAll(AstTreeStorage &storage, SymbolTable &symbol_table,
const std::string &identifier,
std::shared_ptr<LogicalOperator> input = {nullptr}) {
auto node = NODE(identifier);
auto logical_op = std::make_shared<ScanAll>(node, input);
auto symbol = symbol_table.CreateSymbol(identifier);
symbol_table[*node->identifier_] = symbol;
// return std::make_tuple(node, logical_op, symbol);
return ScanAllTuple{node, logical_op, symbol};
}
struct ExpandTuple {
EdgeAtom *edge_;
Symbol edge_sym_;
NodeAtom *node_;
Symbol node_sym_;
std::shared_ptr<LogicalOperator> op_;
};
ExpandTuple MakeExpand(AstTreeStorage &storage, SymbolTable &symbol_table,
std::shared_ptr<LogicalOperator> input,
Symbol input_symbol, const std::string &edge_identifier,
EdgeAtom::Direction direction, bool edge_cycle,
const std::string &node_identifier, bool node_cycle) {
auto edge = EDGE(edge_identifier, direction);
auto edge_sym = symbol_table.CreateSymbol(edge_identifier);
symbol_table[*edge->identifier_] = edge_sym;
auto node = NODE(node_identifier);
auto node_sym = symbol_table.CreateSymbol(node_identifier);
symbol_table[*node->identifier_] = node_sym;
auto op = std::make_shared<Expand>(node, edge, input, input_symbol,
node_cycle, edge_cycle);
return ExpandTuple{edge, edge_sym, node, node_sym, op};
}
template <typename TIterable>
auto CountIterable(TIterable iterable) {
return std::distance(iterable.begin(), iterable.end());
}
/*
* Actual tests start here.
*/
TEST(Interpreter, MatchReturn) {
Dbms dbms;
auto dba = dbms.active();
// add a few nodes to the database
dba->insert_vertex();
dba->insert_vertex();
dba->advance_command();
AstTreeStorage storage;
SymbolTable symbol_table;
auto scan_all = MakeScanAll(storage, symbol_table, "n");
auto output = NEXPR("n", IDENT("n"));
auto produce = MakeProduce(scan_all.op_, output);
symbol_table[*output->expression_] = scan_all.sym_;
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1");
ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba);
EXPECT_EQ(result.GetResults().size(), 2);
}
TEST(Interpreter, MatchReturnCartesian) {
Dbms dbms;
auto dba = dbms.active();
dba->insert_vertex().add_label(dba->label("l1"));
dba->insert_vertex().add_label(dba->label("l2"));
dba->advance_command();
AstTreeStorage storage;
SymbolTable symbol_table;
auto n = MakeScanAll(storage, symbol_table, "n");
auto m = MakeScanAll(storage, symbol_table, "m", n.op_);
auto return_n = NEXPR("n", IDENT("n"));
symbol_table[*return_n->expression_] = n.sym_;
symbol_table[*return_n] = symbol_table.CreateSymbol("named_expression_1");
auto return_m = NEXPR("m", IDENT("m"));
symbol_table[*return_m->expression_] = m.sym_;
symbol_table[*return_m] = symbol_table.CreateSymbol("named_expression_2");
auto produce = MakeProduce(m.op_, return_n, return_m);
ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba);
auto result_data = result.GetResults();
EXPECT_EQ(result_data.size(), 4);
// ensure the result ordering is OK:
// "n" from the results is the same for the first two rows, while "m" isn't
EXPECT_EQ(result_data[0][0].Value<VertexAccessor>(),
result_data[1][0].Value<VertexAccessor>());
EXPECT_NE(result_data[0][1].Value<VertexAccessor>(),
result_data[1][1].Value<VertexAccessor>());
}
TEST(Interpreter, StandaloneReturn) {
Dbms dbms;
auto dba = dbms.active();
// add a few nodes to the database
dba->insert_vertex();
dba->insert_vertex();
dba->advance_command();
AstTreeStorage storage;
SymbolTable symbol_table;
auto output = NEXPR("n", LITERAL(42));
auto produce = MakeProduce(std::shared_ptr<LogicalOperator>(nullptr), output);
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1");
ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba);
EXPECT_EQ(result.GetResults().size(), 1);
EXPECT_EQ(result.GetResults()[0].size(), 1);
EXPECT_EQ(result.GetResults()[0][0].Value<int64_t>(), 42);
}
TEST(Interpreter, NodeFilterLabelsAndProperties) {
Dbms dbms;
auto dba = dbms.active();
// add a few nodes to the database
GraphDbTypes::Label label = dba->label("Label");
GraphDbTypes::Property property = dba->property("Property");
auto v1 = dba->insert_vertex();
auto v2 = dba->insert_vertex();
auto v3 = dba->insert_vertex();
auto v4 = dba->insert_vertex();
auto v5 = dba->insert_vertex();
dba->insert_vertex();
// test all combination of (label | no_label) * (no_prop | wrong_prop |
// right_prop)
// only v1 will have the right labels
v1.add_label(label);
v2.add_label(label);
v3.add_label(label);
v1.PropsSet(property, 42);
v2.PropsSet(property, 1);
v4.PropsSet(property, 42);
v5.PropsSet(property, 1);
dba->advance_command();
AstTreeStorage storage;
SymbolTable symbol_table;
// make a scan all
auto n = MakeScanAll(storage, symbol_table, "n");
n.node_->labels_.emplace_back(label);
n.node_->properties_[property] = LITERAL(42);
// node filtering
auto node_filter = std::make_shared<NodeFilter>(n.op_, n.sym_, n.node_);
// make a named expression and a produce
auto output = NEXPR("x", IDENT("n"));
symbol_table[*output->expression_] = n.sym_;
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1");
auto produce = MakeProduce(node_filter, output);
ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba);
EXPECT_EQ(result.GetResults().size(), 1);
}
TEST(Interpreter, NodeFilterMultipleLabels) {
Dbms dbms;
auto dba = dbms.active();
// add a few nodes to the database
GraphDbTypes::Label label1 = dba->label("label1");
GraphDbTypes::Label label2 = dba->label("label2");
GraphDbTypes::Label label3 = dba->label("label3");
// the test will look for nodes that have label1 and label2
dba->insert_vertex(); // NOT accepted
dba->insert_vertex().add_label(label1); // NOT accepted
dba->insert_vertex().add_label(label2); // NOT accepted
dba->insert_vertex().add_label(label3); // NOT accepted
auto v1 = dba->insert_vertex(); // YES accepted
v1.add_label(label1);
v1.add_label(label2);
auto v2 = dba->insert_vertex(); // NOT accepted
v2.add_label(label1);
v2.add_label(label3);
auto v3 = dba->insert_vertex(); // YES accepted
v3.add_label(label1);
v3.add_label(label2);
v3.add_label(label3);
dba->advance_command();
AstTreeStorage storage;
SymbolTable symbol_table;
// make a scan all
auto n = MakeScanAll(storage, symbol_table, "n");
n.node_->labels_.emplace_back(label1);
n.node_->labels_.emplace_back(label2);
// node filtering
auto node_filter = std::make_shared<NodeFilter>(n.op_, n.sym_, n.node_);
// make a named expression and a produce
auto output = NEXPR("n", IDENT("n"));
auto produce = MakeProduce(node_filter, output);
// fill up the symbol table
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1");
symbol_table[*output->expression_] = n.sym_;
ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba);
EXPECT_EQ(result.GetResults().size(), 2);
}
TEST(Interpreter, CreateNodeWithAttributes) {
TEST(QueryPlan, CreateNodeWithAttributes) {
Dbms dbms;
auto dba = dbms.active();
@ -341,7 +55,7 @@ TEST(Interpreter, CreateNodeWithAttributes) {
EXPECT_EQ(vertex_count, 1);
}
TEST(Interpreter, CreateReturn) {
TEST(QueryPlan, CreateReturn) {
// test CREATE (n:Person {age: 42}) RETURN n, n.age
Dbms dbms;
auto dba = dbms.active();
@ -384,7 +98,7 @@ TEST(Interpreter, CreateReturn) {
EXPECT_EQ(1, CountIterable(dba->vertices()));
}
TEST(Interpreter, CreateExpand) {
TEST(QueryPlan, CreateExpand) {
Dbms dbms;
auto dba = dbms.active();
@ -457,7 +171,7 @@ TEST(Interpreter, CreateExpand) {
}
}
TEST(Interpreter, MatchCreateNode) {
TEST(QueryPlan, MatchCreateNode) {
Dbms dbms;
auto dba = dbms.active();
@ -484,7 +198,7 @@ TEST(Interpreter, MatchCreateNode) {
EXPECT_EQ(CountIterable(dba->vertices()), 6);
}
TEST(Interpreter, MatchCreateExpand) {
TEST(QueryPlan, MatchCreateExpand) {
Dbms dbms;
auto dba = dbms.active();
@ -535,223 +249,7 @@ TEST(Interpreter, MatchCreateExpand) {
test_create_path(true, 0, 6);
}
TEST(Interpreter, Expand) {
Dbms dbms;
auto dba = dbms.active();
// make a V-graph (v3)<-[r2]-(v1)-[r1]->(v2)
auto v1 = dba->insert_vertex();
v1.add_label((GraphDbTypes::Label)1);
auto v2 = dba->insert_vertex();
v2.add_label((GraphDbTypes::Label)2);
auto v3 = dba->insert_vertex();
v3.add_label((GraphDbTypes::Label)3);
auto edge_type = dba->edge_type("Edge");
dba->insert_edge(v1, v2, edge_type);
dba->insert_edge(v1, v3, edge_type);
dba->advance_command();
AstTreeStorage storage;
SymbolTable symbol_table;
auto test_expand = [&](EdgeAtom::Direction direction,
int expected_result_count) {
auto n = MakeScanAll(storage, symbol_table, "n");
auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", direction,
false, "m", false);
// make a named expression and a produce
auto output = NEXPR("m", IDENT("m"));
symbol_table[*output->expression_] = r_m.node_sym_;
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1");
auto produce = MakeProduce(r_m.op_, output);
ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba);
EXPECT_EQ(result.GetResults().size(), expected_result_count);
};
test_expand(EdgeAtom::Direction::RIGHT, 2);
test_expand(EdgeAtom::Direction::LEFT, 2);
test_expand(EdgeAtom::Direction::BOTH, 4);
}
TEST(Interpreter, ExpandNodeCycle) {
Dbms dbms;
auto dba = dbms.active();
// make a graph (v1)->(v2) that
// has a recursive edge (v1)->(v1)
auto v1 = dba->insert_vertex();
auto v2 = dba->insert_vertex();
auto edge_type = dba->edge_type("Edge");
dba->insert_edge(v1, v1, edge_type);
dba->insert_edge(v1, v2, edge_type);
dba->advance_command();
AstTreeStorage storage;
SymbolTable symbol_table;
auto test_cycle = [&](bool with_cycle, int expected_result_count) {
auto n = MakeScanAll(storage, symbol_table, "n");
auto r_n = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r",
EdgeAtom::Direction::RIGHT, false, "n", with_cycle);
if (with_cycle)
symbol_table[*r_n.node_->identifier_] =
symbol_table[*n.node_->identifier_];
// make a named expression and a produce
auto output = NEXPR("n", IDENT("n"));
symbol_table[*output->expression_] = n.sym_;
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1");
auto produce = MakeProduce(r_n.op_, output);
ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba);
EXPECT_EQ(result.GetResults().size(), expected_result_count);
};
test_cycle(true, 1);
test_cycle(false, 2);
}
TEST(Interpreter, ExpandEdgeCycle) {
Dbms dbms;
auto dba = dbms.active();
// make a V-graph (v3)<-[r2]-(v1)-[r1]->(v2)
auto v1 = dba->insert_vertex();
v1.add_label((GraphDbTypes::Label)1);
auto v2 = dba->insert_vertex();
v2.add_label((GraphDbTypes::Label)2);
auto v3 = dba->insert_vertex();
v3.add_label((GraphDbTypes::Label)3);
auto edge_type = dba->edge_type("Edge");
dba->insert_edge(v1, v2, edge_type);
dba->insert_edge(v1, v3, edge_type);
dba->advance_command();
AstTreeStorage storage;
SymbolTable symbol_table;
auto test_cycle = [&](bool with_cycle, int expected_result_count) {
auto i = MakeScanAll(storage, symbol_table, "i");
auto r_j = MakeExpand(storage, symbol_table, i.op_, i.sym_, "r",
EdgeAtom::Direction::BOTH, false, "j", false);
auto r_k = MakeExpand(storage, symbol_table, r_j.op_, r_j.node_sym_, "r",
EdgeAtom::Direction::BOTH, with_cycle, "k", false);
if (with_cycle)
symbol_table[*r_k.edge_->identifier_] =
symbol_table[*r_j.edge_->identifier_];
// make a named expression and a produce
auto output = NEXPR("r", IDENT("r"));
symbol_table[*output->expression_] = r_j.edge_sym_;
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1");
auto produce = MakeProduce(r_k.op_, output);
ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba);
EXPECT_EQ(result.GetResults().size(), expected_result_count);
};
test_cycle(true, 4);
test_cycle(false, 6);
}
TEST(Interpreter, EdgeFilter) {
Dbms dbms;
auto dba = dbms.active();
// make an N-star expanding from (v1)
// where only one edge will qualify
// and there are all combinations of
// (edge_type yes|no) * (property yes|absent|no)
std::vector<GraphDbTypes::EdgeType> edge_types;
for (int j = 0; j < 2; ++j)
edge_types.push_back(dba->edge_type("et" + std::to_string(j)));
std::vector<VertexAccessor> vertices;
for (int i = 0; i < 7; ++i) vertices.push_back(dba->insert_vertex());
GraphDbTypes::Property prop = dba->property("prop");
std::vector<EdgeAccessor> edges;
for (int i = 0; i < 6; ++i) {
edges.push_back(
dba->insert_edge(vertices[0], vertices[i + 1], edge_types[i % 2]));
switch (i % 3) {
case 0:
edges.back().PropsSet(prop, 42);
break;
case 1:
edges.back().PropsSet(prop, 100);
break;
default:
break;
}
}
dba->advance_command();
AstTreeStorage storage;
SymbolTable symbol_table;
// define an operator tree for query
// MATCH (n)-[r]->(m) RETURN m
auto n = MakeScanAll(storage, symbol_table, "n");
auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r",
EdgeAtom::Direction::RIGHT, false, "m", false);
r_m.edge_->edge_types_.push_back(edge_types[0]);
r_m.edge_->properties_[prop] = LITERAL(42);
auto edge_filter =
std::make_shared<EdgeFilter>(r_m.op_, r_m.edge_sym_, r_m.edge_);
// make a named expression and a produce
auto output = NEXPR("m", IDENT("m"));
symbol_table[*output->expression_] = r_m.node_sym_;
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1");
auto produce = MakeProduce(edge_filter, output);
ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba);
EXPECT_EQ(result.GetResults().size(), 1);
}
TEST(Interpreter, EdgeFilterMultipleTypes) {
Dbms dbms;
auto dba = dbms.active();
auto v1 = dba->insert_vertex();
auto v2 = dba->insert_vertex();
auto type_1 = dba->edge_type("type_1");
auto type_2 = dba->edge_type("type_2");
auto type_3 = dba->edge_type("type_3");
dba->insert_edge(v1, v2, type_1);
dba->insert_edge(v1, v2, type_2);
dba->insert_edge(v1, v2, type_3);
dba->advance_command();
AstTreeStorage storage;
SymbolTable symbol_table;
// make a scan all
auto n = MakeScanAll(storage, symbol_table, "n");
auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r",
EdgeAtom::Direction::RIGHT, false, "m", false);
// add a property filter
auto edge_filter =
std::make_shared<EdgeFilter>(r_m.op_, r_m.edge_sym_, r_m.edge_);
r_m.edge_->edge_types_.push_back(type_1);
r_m.edge_->edge_types_.push_back(type_2);
// make a named expression and a produce
auto output = NEXPR("m", IDENT("m"));
auto produce = MakeProduce(edge_filter, output);
// fill up the symbol table
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1");
symbol_table[*output->expression_] = r_m.node_sym_;
ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba);
EXPECT_EQ(result.GetResults().size(), 2);
}
TEST(Interpreter, Delete) {
TEST(QueryPlan, Delete) {
Dbms dbms;
auto dba = dbms.active();
@ -826,7 +324,7 @@ TEST(Interpreter, Delete) {
}
}
TEST(Interpreter, DeleteReturn) {
TEST(QueryPlan, DeleteReturn) {
Dbms dbms;
auto dba = dbms.active();
@ -864,36 +362,7 @@ TEST(Interpreter, DeleteReturn) {
EXPECT_EQ(0, CountIterable(dba->vertices()));
}
TEST(Interpreter, Filter) {
Dbms dbms;
auto dba = dbms.active();
// add a 6 nodes with property 'prop', 2 have true as value
GraphDbTypes::Property property = dba->property("Property");
for (int i = 0; i < 6; ++i)
dba->insert_vertex().PropsSet(property, i % 3 == 0);
dba->insert_vertex(); // prop not set, gives NULL
dba->advance_command();
AstTreeStorage storage;
SymbolTable symbol_table;
auto n = MakeScanAll(storage, symbol_table, "n");
auto e =
storage.Create<PropertyLookup>(storage.Create<Identifier>("n"), property);
symbol_table[*e->expression_] = n.sym_;
auto f = std::make_shared<Filter>(n.op_, e);
auto output =
storage.Create<NamedExpression>("x", storage.Create<Identifier>("n"));
symbol_table[*output->expression_] = n.sym_;
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1");
auto produce = MakeProduce(f, output);
EXPECT_EQ(CollectProduce(produce, symbol_table, *dba).GetResults().size(), 2);
}
TEST(Interpreter, SetProperty) {
TEST(QueryPlan, SetProperty) {
Dbms dbms;
auto dba = dbms.active();
@ -943,7 +412,7 @@ TEST(Interpreter, SetProperty) {
}
}
TEST(Interpreter, SetProperties) {
TEST(QueryPlan, SetProperties) {
auto test_set_properties = [](bool update) {
Dbms dbms;
auto dba = dbms.active();
@ -1013,7 +482,7 @@ TEST(Interpreter, SetProperties) {
test_set_properties(false);
}
TEST(Interpreter, SetLabels) {
TEST(QueryPlan, SetLabels) {
Dbms dbms;
auto dba = dbms.active();
@ -1040,7 +509,7 @@ TEST(Interpreter, SetLabels) {
}
}
TEST(Interpreter, RemoveProperty) {
TEST(QueryPlan, RemoveProperty) {
Dbms dbms;
auto dba = dbms.active();
@ -1092,7 +561,7 @@ TEST(Interpreter, RemoveProperty) {
}
}
TEST(Interpreter, RemoveLabels) {
TEST(QueryPlan, RemoveLabels) {
Dbms dbms;
auto dba = dbms.active();
@ -1124,7 +593,7 @@ TEST(Interpreter, RemoveLabels) {
}
}
TEST(Interpreter, NodeFilterSet) {
TEST(QueryPlan, NodeFilterSet) {
Dbms dbms;
auto dba = dbms.active();
// Create a graph such that (v1 {prop: 42}) is connected to v2 and v3.
@ -1162,7 +631,7 @@ TEST(Interpreter, NodeFilterSet) {
EXPECT_TRUE(prop_eq.Value<bool>());
}
TEST(Interpreter, FilterRemove) {
TEST(QueryPlan, FilterRemove) {
Dbms dbms;
auto dba = dbms.active();
// Create a graph such that (v1 {prop: 42}) is connected to v2 and v3.
@ -1198,7 +667,7 @@ TEST(Interpreter, FilterRemove) {
EXPECT_EQ(v1.PropsAt(prop).type(), PropertyValue::Type::Null);
}
TEST(Interpreter, SetRemove) {
TEST(QueryPlan, SetRemove) {
Dbms dbms;
auto dba = dbms.active();
auto v = dba->insert_vertex();
@ -1222,132 +691,3 @@ TEST(Interpreter, SetRemove) {
EXPECT_FALSE(v.has_label(label1));
EXPECT_FALSE(v.has_label(label2));
}
TEST(Interpreter, ExpandUniquenessFilter) {
Dbms dbms;
auto dba = dbms.active();
// make a graph that has (v1)->(v2) and a recursive edge (v1)->(v1)
auto v1 = dba->insert_vertex();
auto v2 = dba->insert_vertex();
auto edge_type = dba->edge_type("edge_type");
dba->insert_edge(v1, v2, edge_type);
dba->insert_edge(v1, v1, edge_type);
dba->advance_command();
auto check_expand_results = [&](bool vertex_uniqueness,
bool edge_uniqueness) {
AstTreeStorage storage;
SymbolTable symbol_table;
auto n1 = MakeScanAll(storage, symbol_table, "n1");
auto r1_n2 = MakeExpand(storage, symbol_table, n1.op_, n1.sym_, "r1",
EdgeAtom::Direction::RIGHT, false, "n2", false);
std::shared_ptr<LogicalOperator> last_op = r1_n2.op_;
if (vertex_uniqueness)
last_op = std::make_shared<ExpandUniquenessFilter<VertexAccessor>>(
last_op, r1_n2.node_sym_, std::vector<Symbol>{n1.sym_});
auto r2_n3 =
MakeExpand(storage, symbol_table, last_op, r1_n2.node_sym_, "r2",
EdgeAtom::Direction::RIGHT, false, "n3", false);
last_op = r2_n3.op_;
if (edge_uniqueness)
last_op = std::make_shared<ExpandUniquenessFilter<EdgeAccessor>>(
last_op, r2_n3.edge_sym_, std::vector<Symbol>{r1_n2.edge_sym_});
if (vertex_uniqueness)
last_op = std::make_shared<ExpandUniquenessFilter<VertexAccessor>>(
last_op, r2_n3.node_sym_,
std::vector<Symbol>{n1.sym_, r1_n2.node_sym_});
return PullAll(last_op, *dba, symbol_table);
};
EXPECT_EQ(2, check_expand_results(false, false));
EXPECT_EQ(0, check_expand_results(true, false));
EXPECT_EQ(1, check_expand_results(false, true));
}
TEST(Interpreter, Accumulate) {
// simulate the following two query execution on an empty db
// CREATE ({x:0})-[:T]->({x:0})
// MATCH (n)--(m) SET n.x = n.x + 1, m.x = m.x + 1 RETURN n.x, m.x
// without accumulation we expected results to be [[1, 1], [2, 2]]
// with accumulation we expect them to be [[2, 2], [2, 2]]
auto check = [&](bool accumulate) {
Dbms dbms;
auto dba = dbms.active();
auto prop = dba->property("x");
auto v1 = dba->insert_vertex();
v1.PropsSet(prop, 0);
auto v2 = dba->insert_vertex();
v2.PropsSet(prop, 0);
dba->insert_edge(v1, v2, dba->edge_type("T"));
dba->advance_command();
AstTreeStorage storage;
SymbolTable symbol_table;
auto n = MakeScanAll(storage, symbol_table, "n");
auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r",
EdgeAtom::Direction::BOTH, false, "m", false);
auto one = LITERAL(1);
auto n_p = PROPERTY_LOOKUP("n", prop);
symbol_table[*n_p->expression_] = n.sym_;
auto set_n_p =
std::make_shared<plan::SetProperty>(r_m.op_, n_p, ADD(n_p, one));
auto m_p = PROPERTY_LOOKUP("m", prop);
symbol_table[*m_p->expression_] = r_m.node_sym_;
auto set_m_p =
std::make_shared<plan::SetProperty>(set_n_p, m_p, ADD(m_p, one));
std::shared_ptr<LogicalOperator> last_op = set_m_p;
if (accumulate) {
last_op = std::make_shared<Accumulate>(
last_op, std::vector<Symbol>{n.sym_, r_m.node_sym_});
}
auto n_p_ne = NEXPR("n.p", n_p);
symbol_table[*n_p_ne] = symbol_table.CreateSymbol("n_p_ne");
auto m_p_ne = NEXPR("m.p", m_p);
symbol_table[*m_p_ne] = symbol_table.CreateSymbol("m_p_ne");
auto produce = MakeProduce(last_op, n_p_ne, m_p_ne);
ResultStreamFaker results = CollectProduce(produce, symbol_table, *dba);
std::vector<int> results_data;
for (const auto &row : results.GetResults())
for (const auto &column : row)
results_data.emplace_back(column.Value<int64_t>());
if (accumulate)
EXPECT_THAT(results_data, testing::ElementsAre(2, 2, 2, 2));
else
EXPECT_THAT(results_data, testing::ElementsAre(1, 1, 2, 2));
};
check(false);
check(true);
}
TEST(Interpreter, AccumulateAdvance) {
// we simulate 'CREATE (n) WITH n AS n MATCH (m) RETURN m'
// to get correct results we need to advance the command
auto check = [&](bool advance) {
Dbms dbms;
auto dba = dbms.active();
AstTreeStorage storage;
SymbolTable symbol_table;
auto node = NODE("n");
auto sym_n = symbol_table.CreateSymbol("n");
symbol_table[*node->identifier_] = sym_n;
auto create = std::make_shared<CreateNode>(node, nullptr);
auto accumulate = std::make_shared<Accumulate>(
create, std::vector<Symbol>{sym_n}, advance);
auto match = MakeScanAll(storage, symbol_table, "m", accumulate);
EXPECT_EQ(advance ? 1 : 0, PullAll(match.op_, *dba, symbol_table));
};
check(false);
check(true);
}

View File

@ -0,0 +1,482 @@
//
// Copyright 2017 Memgraph
// Created by Florijan Stamenkovic on 14.03.17.
//
#include <iterator>
#include <memory>
#include <vector>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "communication/result_stream_faker.hpp"
#include "dbms/dbms.hpp"
#include "query/context.hpp"
#include "query/exceptions.hpp"
#include "query/frontend/interpret/interpret.hpp"
#include "query/frontend/logical/operator.hpp"
#include "query_plan_common.hpp"
using namespace query;
using namespace query::plan;
TEST(QueryPlan, MatchReturn) {
Dbms dbms;
auto dba = dbms.active();
// add a few nodes to the database
dba->insert_vertex();
dba->insert_vertex();
dba->advance_command();
AstTreeStorage storage;
SymbolTable symbol_table;
auto scan_all = MakeScanAll(storage, symbol_table, "n");
auto output = NEXPR("n", IDENT("n"));
auto produce = MakeProduce(scan_all.op_, output);
symbol_table[*output->expression_] = scan_all.sym_;
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1");
ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba);
EXPECT_EQ(result.GetResults().size(), 2);
}
TEST(QueryPlan, MatchReturnCartesian) {
Dbms dbms;
auto dba = dbms.active();
dba->insert_vertex().add_label(dba->label("l1"));
dba->insert_vertex().add_label(dba->label("l2"));
dba->advance_command();
AstTreeStorage storage;
SymbolTable symbol_table;
auto n = MakeScanAll(storage, symbol_table, "n");
auto m = MakeScanAll(storage, symbol_table, "m", n.op_);
auto return_n = NEXPR("n", IDENT("n"));
symbol_table[*return_n->expression_] = n.sym_;
symbol_table[*return_n] = symbol_table.CreateSymbol("named_expression_1");
auto return_m = NEXPR("m", IDENT("m"));
symbol_table[*return_m->expression_] = m.sym_;
symbol_table[*return_m] = symbol_table.CreateSymbol("named_expression_2");
auto produce = MakeProduce(m.op_, return_n, return_m);
ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba);
auto result_data = result.GetResults();
EXPECT_EQ(result_data.size(), 4);
// ensure the result ordering is OK:
// "n" from the results is the same for the first two rows, while "m" isn't
EXPECT_EQ(result_data[0][0].Value<VertexAccessor>(),
result_data[1][0].Value<VertexAccessor>());
EXPECT_NE(result_data[0][1].Value<VertexAccessor>(),
result_data[1][1].Value<VertexAccessor>());
}
TEST(QueryPlan, StandaloneReturn) {
Dbms dbms;
auto dba = dbms.active();
// add a few nodes to the database
dba->insert_vertex();
dba->insert_vertex();
dba->advance_command();
AstTreeStorage storage;
SymbolTable symbol_table;
auto output = NEXPR("n", LITERAL(42));
auto produce = MakeProduce(std::shared_ptr<LogicalOperator>(nullptr), output);
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1");
ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba);
EXPECT_EQ(result.GetResults().size(), 1);
EXPECT_EQ(result.GetResults()[0].size(), 1);
EXPECT_EQ(result.GetResults()[0][0].Value<int64_t>(), 42);
}
TEST(QueryPlan, NodeFilterLabelsAndProperties) {
Dbms dbms;
auto dba = dbms.active();
// add a few nodes to the database
GraphDbTypes::Label label = dba->label("Label");
GraphDbTypes::Property property = dba->property("Property");
auto v1 = dba->insert_vertex();
auto v2 = dba->insert_vertex();
auto v3 = dba->insert_vertex();
auto v4 = dba->insert_vertex();
auto v5 = dba->insert_vertex();
dba->insert_vertex();
// test all combination of (label | no_label) * (no_prop | wrong_prop |
// right_prop)
// only v1 will have the right labels
v1.add_label(label);
v2.add_label(label);
v3.add_label(label);
v1.PropsSet(property, 42);
v2.PropsSet(property, 1);
v4.PropsSet(property, 42);
v5.PropsSet(property, 1);
dba->advance_command();
AstTreeStorage storage;
SymbolTable symbol_table;
// make a scan all
auto n = MakeScanAll(storage, symbol_table, "n");
n.node_->labels_.emplace_back(label);
n.node_->properties_[property] = LITERAL(42);
// node filtering
auto node_filter = std::make_shared<NodeFilter>(n.op_, n.sym_, n.node_);
// make a named expression and a produce
auto output = NEXPR("x", IDENT("n"));
symbol_table[*output->expression_] = n.sym_;
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1");
auto produce = MakeProduce(node_filter, output);
ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba);
EXPECT_EQ(result.GetResults().size(), 1);
}
TEST(QueryPlan, NodeFilterMultipleLabels) {
Dbms dbms;
auto dba = dbms.active();
// add a few nodes to the database
GraphDbTypes::Label label1 = dba->label("label1");
GraphDbTypes::Label label2 = dba->label("label2");
GraphDbTypes::Label label3 = dba->label("label3");
// the test will look for nodes that have label1 and label2
dba->insert_vertex(); // NOT accepted
dba->insert_vertex().add_label(label1); // NOT accepted
dba->insert_vertex().add_label(label2); // NOT accepted
dba->insert_vertex().add_label(label3); // NOT accepted
auto v1 = dba->insert_vertex(); // YES accepted
v1.add_label(label1);
v1.add_label(label2);
auto v2 = dba->insert_vertex(); // NOT accepted
v2.add_label(label1);
v2.add_label(label3);
auto v3 = dba->insert_vertex(); // YES accepted
v3.add_label(label1);
v3.add_label(label2);
v3.add_label(label3);
dba->advance_command();
AstTreeStorage storage;
SymbolTable symbol_table;
// make a scan all
auto n = MakeScanAll(storage, symbol_table, "n");
n.node_->labels_.emplace_back(label1);
n.node_->labels_.emplace_back(label2);
// node filtering
auto node_filter = std::make_shared<NodeFilter>(n.op_, n.sym_, n.node_);
// make a named expression and a produce
auto output = NEXPR("n", IDENT("n"));
auto produce = MakeProduce(node_filter, output);
// fill up the symbol table
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1");
symbol_table[*output->expression_] = n.sym_;
ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba);
EXPECT_EQ(result.GetResults().size(), 2);
}
TEST(QueryPlan, Expand) {
Dbms dbms;
auto dba = dbms.active();
// make a V-graph (v3)<-[r2]-(v1)-[r1]->(v2)
auto v1 = dba->insert_vertex();
v1.add_label((GraphDbTypes::Label)1);
auto v2 = dba->insert_vertex();
v2.add_label((GraphDbTypes::Label)2);
auto v3 = dba->insert_vertex();
v3.add_label((GraphDbTypes::Label)3);
auto edge_type = dba->edge_type("Edge");
dba->insert_edge(v1, v2, edge_type);
dba->insert_edge(v1, v3, edge_type);
dba->advance_command();
AstTreeStorage storage;
SymbolTable symbol_table;
auto test_expand = [&](EdgeAtom::Direction direction,
int expected_result_count) {
auto n = MakeScanAll(storage, symbol_table, "n");
auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", direction,
false, "m", false);
// make a named expression and a produce
auto output = NEXPR("m", IDENT("m"));
symbol_table[*output->expression_] = r_m.node_sym_;
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1");
auto produce = MakeProduce(r_m.op_, output);
ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba);
EXPECT_EQ(result.GetResults().size(), expected_result_count);
};
test_expand(EdgeAtom::Direction::RIGHT, 2);
test_expand(EdgeAtom::Direction::LEFT, 2);
test_expand(EdgeAtom::Direction::BOTH, 4);
}
TEST(QueryPlan, ExpandNodeCycle) {
Dbms dbms;
auto dba = dbms.active();
// make a graph (v1)->(v2) that
// has a recursive edge (v1)->(v1)
auto v1 = dba->insert_vertex();
auto v2 = dba->insert_vertex();
auto edge_type = dba->edge_type("Edge");
dba->insert_edge(v1, v1, edge_type);
dba->insert_edge(v1, v2, edge_type);
dba->advance_command();
AstTreeStorage storage;
SymbolTable symbol_table;
auto test_cycle = [&](bool with_cycle, int expected_result_count) {
auto n = MakeScanAll(storage, symbol_table, "n");
auto r_n = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r",
EdgeAtom::Direction::RIGHT, false, "n", with_cycle);
if (with_cycle)
symbol_table[*r_n.node_->identifier_] =
symbol_table[*n.node_->identifier_];
// make a named expression and a produce
auto output = NEXPR("n", IDENT("n"));
symbol_table[*output->expression_] = n.sym_;
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1");
auto produce = MakeProduce(r_n.op_, output);
ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba);
EXPECT_EQ(result.GetResults().size(), expected_result_count);
};
test_cycle(true, 1);
test_cycle(false, 2);
}
TEST(QueryPlan, ExpandEdgeCycle) {
Dbms dbms;
auto dba = dbms.active();
// make a V-graph (v3)<-[r2]-(v1)-[r1]->(v2)
auto v1 = dba->insert_vertex();
v1.add_label((GraphDbTypes::Label)1);
auto v2 = dba->insert_vertex();
v2.add_label((GraphDbTypes::Label)2);
auto v3 = dba->insert_vertex();
v3.add_label((GraphDbTypes::Label)3);
auto edge_type = dba->edge_type("Edge");
dba->insert_edge(v1, v2, edge_type);
dba->insert_edge(v1, v3, edge_type);
dba->advance_command();
AstTreeStorage storage;
SymbolTable symbol_table;
auto test_cycle = [&](bool with_cycle, int expected_result_count) {
auto i = MakeScanAll(storage, symbol_table, "i");
auto r_j = MakeExpand(storage, symbol_table, i.op_, i.sym_, "r",
EdgeAtom::Direction::BOTH, false, "j", false);
auto r_k = MakeExpand(storage, symbol_table, r_j.op_, r_j.node_sym_, "r",
EdgeAtom::Direction::BOTH, with_cycle, "k", false);
if (with_cycle)
symbol_table[*r_k.edge_->identifier_] =
symbol_table[*r_j.edge_->identifier_];
// make a named expression and a produce
auto output = NEXPR("r", IDENT("r"));
symbol_table[*output->expression_] = r_j.edge_sym_;
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1");
auto produce = MakeProduce(r_k.op_, output);
ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba);
EXPECT_EQ(result.GetResults().size(), expected_result_count);
};
test_cycle(true, 4);
test_cycle(false, 6);
}
TEST(QueryPlan, EdgeFilter) {
Dbms dbms;
auto dba = dbms.active();
// make an N-star expanding from (v1)
// where only one edge will qualify
// and there are all combinations of
// (edge_type yes|no) * (property yes|absent|no)
std::vector<GraphDbTypes::EdgeType> edge_types;
for (int j = 0; j < 2; ++j)
edge_types.push_back(dba->edge_type("et" + std::to_string(j)));
std::vector<VertexAccessor> vertices;
for (int i = 0; i < 7; ++i) vertices.push_back(dba->insert_vertex());
GraphDbTypes::Property prop = dba->property("prop");
std::vector<EdgeAccessor> edges;
for (int i = 0; i < 6; ++i) {
edges.push_back(
dba->insert_edge(vertices[0], vertices[i + 1], edge_types[i % 2]));
switch (i % 3) {
case 0:
edges.back().PropsSet(prop, 42);
break;
case 1:
edges.back().PropsSet(prop, 100);
break;
default:
break;
}
}
dba->advance_command();
AstTreeStorage storage;
SymbolTable symbol_table;
// define an operator tree for query
// MATCH (n)-[r]->(m) RETURN m
auto n = MakeScanAll(storage, symbol_table, "n");
auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r",
EdgeAtom::Direction::RIGHT, false, "m", false);
r_m.edge_->edge_types_.push_back(edge_types[0]);
r_m.edge_->properties_[prop] = LITERAL(42);
auto edge_filter =
std::make_shared<EdgeFilter>(r_m.op_, r_m.edge_sym_, r_m.edge_);
// make a named expression and a produce
auto output = NEXPR("m", IDENT("m"));
symbol_table[*output->expression_] = r_m.node_sym_;
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1");
auto produce = MakeProduce(edge_filter, output);
ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba);
EXPECT_EQ(result.GetResults().size(), 1);
}
TEST(QueryPlan, EdgeFilterMultipleTypes) {
Dbms dbms;
auto dba = dbms.active();
auto v1 = dba->insert_vertex();
auto v2 = dba->insert_vertex();
auto type_1 = dba->edge_type("type_1");
auto type_2 = dba->edge_type("type_2");
auto type_3 = dba->edge_type("type_3");
dba->insert_edge(v1, v2, type_1);
dba->insert_edge(v1, v2, type_2);
dba->insert_edge(v1, v2, type_3);
dba->advance_command();
AstTreeStorage storage;
SymbolTable symbol_table;
// make a scan all
auto n = MakeScanAll(storage, symbol_table, "n");
auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r",
EdgeAtom::Direction::RIGHT, false, "m", false);
// add a property filter
auto edge_filter =
std::make_shared<EdgeFilter>(r_m.op_, r_m.edge_sym_, r_m.edge_);
r_m.edge_->edge_types_.push_back(type_1);
r_m.edge_->edge_types_.push_back(type_2);
// make a named expression and a produce
auto output = NEXPR("m", IDENT("m"));
auto produce = MakeProduce(edge_filter, output);
// fill up the symbol table
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1");
symbol_table[*output->expression_] = r_m.node_sym_;
ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba);
EXPECT_EQ(result.GetResults().size(), 2);
}
TEST(QueryPlan, Filter) {
Dbms dbms;
auto dba = dbms.active();
// add a 6 nodes with property 'prop', 2 have true as value
GraphDbTypes::Property property = dba->property("Property");
for (int i = 0; i < 6; ++i)
dba->insert_vertex().PropsSet(property, i % 3 == 0);
dba->insert_vertex(); // prop not set, gives NULL
dba->advance_command();
AstTreeStorage storage;
SymbolTable symbol_table;
auto n = MakeScanAll(storage, symbol_table, "n");
auto e =
storage.Create<PropertyLookup>(storage.Create<Identifier>("n"), property);
symbol_table[*e->expression_] = n.sym_;
auto f = std::make_shared<Filter>(n.op_, e);
auto output =
storage.Create<NamedExpression>("x", storage.Create<Identifier>("n"));
symbol_table[*output->expression_] = n.sym_;
symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1");
auto produce = MakeProduce(f, output);
EXPECT_EQ(CollectProduce(produce, symbol_table, *dba).GetResults().size(), 2);
}
TEST(QueryPlan, ExpandUniquenessFilter) {
Dbms dbms;
auto dba = dbms.active();
// make a graph that has (v1)->(v2) and a recursive edge (v1)->(v1)
auto v1 = dba->insert_vertex();
auto v2 = dba->insert_vertex();
auto edge_type = dba->edge_type("edge_type");
dba->insert_edge(v1, v2, edge_type);
dba->insert_edge(v1, v1, edge_type);
dba->advance_command();
auto check_expand_results = [&](bool vertex_uniqueness,
bool edge_uniqueness) {
AstTreeStorage storage;
SymbolTable symbol_table;
auto n1 = MakeScanAll(storage, symbol_table, "n1");
auto r1_n2 = MakeExpand(storage, symbol_table, n1.op_, n1.sym_, "r1",
EdgeAtom::Direction::RIGHT, false, "n2", false);
std::shared_ptr<LogicalOperator> last_op = r1_n2.op_;
if (vertex_uniqueness)
last_op = std::make_shared<ExpandUniquenessFilter<VertexAccessor>>(
last_op, r1_n2.node_sym_, std::vector<Symbol>{n1.sym_});
auto r2_n3 =
MakeExpand(storage, symbol_table, last_op, r1_n2.node_sym_, "r2",
EdgeAtom::Direction::RIGHT, false, "n3", false);
last_op = r2_n3.op_;
if (edge_uniqueness)
last_op = std::make_shared<ExpandUniquenessFilter<EdgeAccessor>>(
last_op, r2_n3.edge_sym_, std::vector<Symbol>{r1_n2.edge_sym_});
if (vertex_uniqueness)
last_op = std::make_shared<ExpandUniquenessFilter<VertexAccessor>>(
last_op, r2_n3.node_sym_,
std::vector<Symbol>{n1.sym_, r1_n2.node_sym_});
return PullAll(last_op, *dba, symbol_table);
};
EXPECT_EQ(2, check_expand_results(false, false));
EXPECT_EQ(0, check_expand_results(true, false));
EXPECT_EQ(1, check_expand_results(false, true));
}