diff --git a/src/query/frontend/interpret/interpret.hpp b/src/query/frontend/interpret/interpret.hpp index 2f9da2d97..b04446c92 100644 --- a/src/query/frontend/interpret/interpret.hpp +++ b/src/query/frontend/interpret/interpret.hpp @@ -146,7 +146,7 @@ class ExpressionEvaluator : public TreeVisitorBase { } void Visit(Aggregation &aggregation) override { - auto value = frame_[symbol_table_[aggregation]]; + auto value = frame_[symbol_table_.at(aggregation)]; // Aggregation is probably always simple type, but let's switch accessor // just to be sure. SwitchAccessors(value); diff --git a/src/query/frontend/logical/operator.cpp b/src/query/frontend/logical/operator.cpp index 8813e2735..9fd551c91 100644 --- a/src/query/frontend/logical/operator.cpp +++ b/src/query/frontend/logical/operator.cpp @@ -1,3 +1,5 @@ +#include + #include "query/frontend/logical/operator.hpp" #include "query/exceptions.hpp" @@ -285,7 +287,6 @@ bool Expand::ExpandCursor::InitEdges(Frame &frame, // will need it). For now only Back expansion (left to right) is // supported // TODO add support for named paths - // TODO add support for uniqueness (edge, vertex) return true; } @@ -892,9 +893,11 @@ void ReconstructTypedValue(TypedValue &value) { for (auto &kv : value.Value>()) ReconstructTypedValue(kv.second); break; + case TypedValue::Type::Path: + // TODO implement path reconstruct? + throw NotYetImplemented("Path reconstruction not yet supported"); default: break; - // TODO implement path reconstruct? } } } @@ -946,8 +949,13 @@ bool Accumulate::AccumulateCursor::Pull(Frame &frame, Aggregate::Aggregate(const std::shared_ptr &input, const std::vector &aggregations, - const std::vector group_by) - : input_(input), aggregations_(aggregations), group_by_(group_by) {} + const std::vector &group_by, + const std::vector &remember, bool advance_command) + : input_(input), + aggregations_(aggregations), + group_by_(group_by), + remember_(remember), + advance_command_(advance_command) {} void Aggregate::Accept(LogicalOperatorVisitor &visitor) { if (visitor.PreVisit(*this)) { @@ -958,7 +966,182 @@ void Aggregate::Accept(LogicalOperatorVisitor &visitor) { } std::unique_ptr Aggregate::MakeCursor(GraphDbAccessor &db) { - return std::unique_ptr(); + return std::make_unique(*this, db); +} + +Aggregate::AggregateCursor::AggregateCursor(Aggregate &self, + GraphDbAccessor &db) + : self_(self), db_(db), input_cursor_(self_.input_->MakeCursor(db)) {} + +bool Aggregate::AggregateCursor::Pull(Frame &frame, + const SymbolTable &symbol_table) { + if (!pulled_all_input_) { + PullAllInput(frame, symbol_table); + pulled_all_input_ = true; + aggregation_it_ = aggregation_.begin(); + + if (self_.advance_command_) { + db_.advance_command(); + // regarding reconstruction after advance_command + // we have to reconstruct only the remember values + // because aggregation results are primitives and + // group-by elements won't be used directly (possibly re-evaluated + // using remember values) + for (auto &kv : aggregation_) + for (TypedValue &remember : kv.second.remember_) + ReconstructTypedValue(remember); + } + } + + if (aggregation_it_ == aggregation_.end()) return false; + + // place aggregation values on the frame + auto aggregation_values_it = aggregation_it_->second.values_.begin(); + for (const auto &aggregation_elem : self_.aggregations_) + frame[std::get<2>(aggregation_elem)] = *aggregation_values_it++; + + // place remember values on the frame + auto remember_values_it = aggregation_it_->second.remember_.begin(); + for (const Symbol &remember_sym : self_.remember_) + frame[remember_sym] = *remember_values_it++; + + aggregation_it_++; + return true; +} + +void Aggregate::AggregateCursor::PullAllInput(Frame &frame, + const SymbolTable &symbol_table) { + ExpressionEvaluator evaluator(frame, symbol_table); + evaluator.SwitchNew(); + + while (input_cursor_->Pull(frame, symbol_table)) { + // create the group-by list of values + std::list group_by; + for (Expression *expression : self_.group_by_) { + expression->Accept(evaluator); + group_by.emplace_back(evaluator.PopBack()); + } + + AggregationValue &agg_value = aggregation_[group_by]; + EnsureInitialized(frame, agg_value); + Update(frame, symbol_table, evaluator, agg_value); + } + + // calculate AVG aggregations (so far they have only been summed) + for (int pos = 0; pos < self_.aggregations_.size(); ++pos) { + if (std::get<1>(self_.aggregations_[pos]) != Aggregation::Op::AVG) continue; + for (auto &kv : aggregation_) { + AggregationValue &agg_value = kv.second; + int count = agg_value.counts_[pos]; + if (count > 0) + agg_value.values_[pos] = agg_value.values_[pos] / (double)count; + } + } +} + +void Aggregate::AggregateCursor::EnsureInitialized( + Frame &frame, Aggregate::AggregateCursor::AggregationValue &agg_value) { + if (agg_value.values_.size() > 0) return; + + agg_value.values_.resize(self_.aggregations_.size(), TypedValue::Null); + agg_value.counts_.resize(self_.aggregations_.size(), 0); + + for (const Symbol &remember_sym : self_.remember_) + agg_value.remember_.push_back(frame[remember_sym]); +} + +void Aggregate::AggregateCursor::Update( + Frame &frame, const SymbolTable &symbol_table, + ExpressionEvaluator &evaluator, + Aggregate::AggregateCursor::AggregationValue &agg_value) { + debug_assert(self_.aggregations_.size() == agg_value.values_.size(), + "Inappropriate AggregationValue.values_ size"); + debug_assert(self_.aggregations_.size() == agg_value.counts_.size(), + "Inappropriate AggregationValue.counts_ size"); + + // we iterate over counts, values and aggregation info at the same time + auto count_it = agg_value.counts_.begin(); + auto value_it = agg_value.values_.begin(); + auto agg_elem_it = self_.aggregations_.begin(); + for (; count_it < agg_value.counts_.end(); + count_it++, value_it++, agg_elem_it++) { + std::get<0>(*agg_elem_it)->Accept(evaluator); + TypedValue input_value = evaluator.PopBack(); + + if (input_value.type() == TypedValue::Type::Null) continue; + + *count_it += 1; + if (*count_it == 1) { + // first value, nothing to aggregate. set and continue. + *value_it = input_value; + continue; + } + + // aggregation of existing values + switch (std::get<1>(*agg_elem_it)) { + case Aggregation::Op::COUNT: + *value_it = *count_it; + break; + case Aggregation::Op::MIN: { + EnsureOkForMinMax(input_value); + // TODO an illegal comparison here will throw a TypedValueException + // consider catching and throwing something else + TypedValue comparison_result = input_value < *value_it; + // since we skip nulls we either have a valid comparison, or + // an exception was just thrown above + // safe to assume a bool TypedValue + if (comparison_result.Value()) *value_it = input_value; + break; + } + case Aggregation::Op::MAX: { + // all comments as for Op::Min + EnsureOkForMinMax(input_value); + TypedValue comparison_result = input_value > *value_it; + if (comparison_result.Value()) *value_it = input_value; + break; + } + case Aggregation::Op::AVG: + // for averaging we sum first and divide by count once all + // the input has been processed + case Aggregation::Op::SUM: + EnsureOkForAvgSum(input_value); + *value_it = *value_it + input_value; + break; + } // end switch over Aggregation::Op enum + } // end loop over all aggregations +} + +void Aggregate::AggregateCursor::EnsureOkForMinMax(const TypedValue &value) { + switch (value.type()) { + case TypedValue::Type::Bool: + case TypedValue::Type::Int: + case TypedValue::Type::Double: + case TypedValue::Type::String: + return; + default: + // TODO consider better error feedback + throw TypedValueException( + "Only Bool, Int, Double and String properties are allowed in " + "MIN and MAX aggregations"); + } +} +void Aggregate::AggregateCursor::EnsureOkForAvgSum(const TypedValue &value) { + switch (value.type()) { + case TypedValue::Type::Int: + case TypedValue::Type::Double: + return; + default: + // TODO consider better error feedback + throw TypedValueException( + "Only numeric properties allowed in SUM and AVG aggregations"); + } +} + +bool Aggregate::AggregateCursor::TypedValueListEqual::operator()( + const std::list &left, + const std::list &right) const { + return std::equal(left.begin(), left.end(), right.begin(), + TypedValue::BoolEqual{}); } } // namespace plan diff --git a/src/query/frontend/logical/operator.hpp b/src/query/frontend/logical/operator.hpp index 69bce2efe..d57a3ef7b 100644 --- a/src/query/frontend/logical/operator.hpp +++ b/src/query/frontend/logical/operator.hpp @@ -2,12 +2,16 @@ #pragma once +#include #include +#include +#include #include #include "database/graph_db_accessor.hpp" #include "database/graph_db_datatypes.hpp" #include "query/frontend/semantic/symbol_table.hpp" +#include "utils/hashing/fnv.hpp" #include "utils/visitor/visitable.hpp" #include "utils/visitor/visitor.hpp" @@ -824,14 +828,94 @@ class Aggregate : public LogicalOperator { Aggregate(const std::shared_ptr &input, const std::vector &aggregations, - const std::vector group_by); + const std::vector &group_by, + const std::vector &remember, bool advance_command = false); void Accept(LogicalOperatorVisitor &visitor) override; std::unique_ptr MakeCursor(GraphDbAccessor &db) override; private: const std::shared_ptr input_; const std::vector aggregations_; - const std::vector group_by_; + const std::vector group_by_; + const std::vector remember_; + const bool advance_command_; + + class AggregateCursor : public Cursor { + public: + AggregateCursor(Aggregate &self, GraphDbAccessor &db); + bool Pull(Frame &frame, const SymbolTable &symbol_table) override; + + private: + // custom equality function for the unordered map + struct TypedValueListEqual { + bool operator()(const std::list &left, + const std::list &right) const; + }; + + // Data structure for a single aggregation cache. + // does NOT include the group-by values since those + // are a key in the aggregation map. + // The vectors in an AggregationValue contain one element + // for each aggregation in this LogicalOp. + struct AggregationValue { + // how many input rows has been aggregated in respective + // values_ element so far + std::vector counts_; + // aggregated values. Initially Null (until at least one + // input row with a valid value gets processed) + std::vector values_; + // remember values. + std::vector remember_; + }; + + Aggregate &self_; + GraphDbAccessor &db_; + std::unique_ptr input_cursor_; + // storage for aggregated data + // map key is the list of group-by values + // map value is an AggregationValue struct + std::unordered_map< + std::list, AggregationValue, + // use FNV collection hashing specialized for a list of TypedValues + FnvCollection, TypedValue, TypedValue::Hash>, + // custom equality + TypedValueListEqual> + aggregation_; + // iterator over the accumulated cache + decltype(aggregation_.begin()) aggregation_it_ = aggregation_.begin(); + // this LogicalOp pulls all from the input on it's first pull + // this switch tracks if this has been performed + bool pulled_all_input_{false}; + + /** + * Pulls from the input until it's exhausted. Accumulates the results + * in the `aggregation_` map. + * + * Accumulation automatically groups the results so that `aggregation_` + * cache cardinality depends on number of + * aggregation results, and not on the number of inputs. + */ + void PullAllInput(Frame &frame, const SymbolTable &symbolTable); + + /** Ensures the new AggregationValue has been initialized. This means + * that the value vectors are filled with an appropriate number of Nulls, + * counts are set to 0 and remember values are remembered. + */ + void EnsureInitialized(Frame &frame, AggregationValue &agg_value); + + /** Updates the given AggregationValue with new data. Assumes that + * the AggregationValue has been initialized */ + void Update(Frame &frame, const SymbolTable &symbol_table, + ExpressionEvaluator &evaluator, AggregationValue &agg_value); + + /** Checks if the given TypedValue is legal in MIN and MAX. If not + * an appropriate exception is thrown. */ + void EnsureOkForMinMax(const TypedValue &value); + + /** Checks if the given TypedValue is legal in AVG and SUM. If not + * an appropriate exception is thrown. */ + void EnsureOkForAvgSum(const TypedValue &value); + }; }; } // namespace plan diff --git a/src/query/typed_value.cpp b/src/query/typed_value.cpp index 6d9a2f8e1..c7f177847 100644 --- a/src/query/typed_value.cpp +++ b/src/query/typed_value.cpp @@ -6,6 +6,8 @@ #include #include "utils/assert.hpp" +#include "utils/exceptions/not_yet_implemented.hpp" +#include "utils/hashing/fnv.hpp" namespace query { @@ -441,6 +443,10 @@ TypedValue operator==(const TypedValue &a, const TypedValue &b) { if (a.type() == TypedValue::Type::Null || b.type() == TypedValue::Type::Null) return TypedValue::Null; + if (a.type() == TypedValue::Type::Map || b.type() == TypedValue::Type::Map) { + throw NotYetImplemented(); + } + if (a.type() == TypedValue::Type::List || b.type() == TypedValue::Type::List) { if (a.type() == TypedValue::Type::List && @@ -731,4 +737,66 @@ TypedValue operator^(const TypedValue &a, const TypedValue &b) { } } +bool TypedValue::BoolEqual::operator()(const TypedValue &lhs, + const TypedValue &rhs) const { + if (lhs.type() == TypedValue::Type::Null && + rhs.type() == TypedValue::Type::Null) + return true; + + // legal comparisons are only between same types + // only int -> float is promoted + if (lhs.type() == rhs.type() || + (lhs.type() == Type::Double && rhs.type() == Type::Int) || + (rhs.type() == Type::Double && lhs.type() == Type::Int)) + { + TypedValue equality_result = lhs == rhs; + switch (equality_result.type()) { + case TypedValue::Type::Bool: + return equality_result.Value(); + case TypedValue::Type::Null: + // we already tested if both operands are null, + // so only one is null here. this evaluates to false equality + return false; + default: + permanent_fail( + "Equality between two TypedValues resulted in something other " + "then Null or bool"); + } + } + + return false; +} + +size_t TypedValue::Hash::operator()(const TypedValue &value) const { + switch (value.type()) { + case TypedValue::Type::Null: + return 31; + case TypedValue::Type::Bool: + return std::hash{}(value.Value()); + case TypedValue::Type::Int: + // we cast int to double for hashing purposes + // to be consistent with TypedValue equality + // in which (2.0 == 2) returns true + return std::hash{}((double)value.Value()); + case TypedValue::Type::Double: + return std::hash{}(value.Value()); + case TypedValue::Type::String: + return std::hash{}(value.Value()); + case TypedValue::Type::List: { + return FnvCollection, TypedValue, Hash>{}( + value.Value>()); + } + case TypedValue::Type::Map: + throw NotYetImplemented(); + case TypedValue::Type::Vertex: + return value.Value().temporary_id(); + case TypedValue::Type::Edge: + return value.Value().temporary_id(); + case TypedValue::Type::Path: + throw NotYetImplemented(); + break; + } + permanent_fail("Unhandled TypedValue.type() in hash function"); +} + } // namespace query diff --git a/src/query/typed_value.hpp b/src/query/typed_value.hpp index cc79aee41..e58a18732 100644 --- a/src/query/typed_value.hpp +++ b/src/query/typed_value.hpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include "storage/edge_accessor.hpp" @@ -29,6 +30,34 @@ typedef traversal_template::Path Path; */ class TypedValue : public TotalOrdering { public: + /** Custom TypedValue equality function that returns a bool + * (as opposed to returning TypedValue as the default equality does). + * This implementation treats two nulls as being equal and null + * not being equal to everything else. + */ + struct BoolEqual { + bool operator()(const TypedValue &left, const TypedValue &right) const; + }; + + /** Hash operator for TypedValue. + * + * Not injecting into std + * due to linking problems. If the implementation is in this header, + * then it implicitly instantiates TypedValue::Value before + * explicit instantiation in .cpp file. If the implementation is in + * the .cpp file, it won't link. + */ + struct Hash { + size_t operator()(const TypedValue &value) const; + }; + + /** + * Unordered set of TypedValue items. Can contain at most one Null element, + * and treats an integral and floating point value as same if they are equal + * in the floating point domain (TypedValue operator== behaves the same). + * */ + using unordered_set = std::unordered_set; + /** Private default constructor, makes Null */ TypedValue() : type_(Type::Null) {} @@ -163,7 +192,6 @@ TypedValue operator||(const TypedValue &a, const TypedValue &b); // Be careful: since ^ is binary operator and || and && are logical operators // they have different priority in c++. TypedValue operator^(const TypedValue &a, const TypedValue &b); - // stream output std::ostream &operator<<(std::ostream &os, const TypedValue::Type type); diff --git a/src/utils/hashing/fnv.hpp b/src/utils/hashing/fnv.hpp index 49561dbd7..ed586ebeb 100644 --- a/src/utils/hashing/fnv.hpp +++ b/src/utils/hashing/fnv.hpp @@ -15,21 +15,53 @@ namespace { #ifdef MEMGRAPH64 -template -uint64_t fnv(const T& data) { - return fnv1a64(data); +__attribute__((unused)) uint64_t fnv(const std::string &s) { + return fnv1a64(s); } using HashType = uint64_t; #elif -template -uint32_t fnv(const T& data) { - return fnv1a32(data); +__attribute__((unused)) uint32_t fnv(const std::string &s) { + return fnv1a32(s); } using HashType = uint32_t; #endif } + +/** + * Does FNV-like hashing on a collection. Not truly FNV + * because it operates on 8-bit elements, while this + * implementation uses size_t elements (collection item + * hash). + * + * https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function + * + * + * @tparam TIterable A collection type that has begin() and end(). + * @tparam TElement Type of element in the collection. + * @tparam THash Hash type (has operator() that accepts a 'const TEelement &' + * and returns size_t. Defaults to std::hash. + * @param iterable A collection of elements. + * @param element_hash Function for hashing a single element. + * @return The hash of the whole collection. + */ +template > +struct FnvCollection { + size_t operator()(const TIterable &iterable) const { + uint64_t hash = 14695981039346656037u; + THash element_hash; + for (const TElement &element : iterable) { + hash *= fnv_prime; + hash ^= element_hash(element); + } + return hash; + } + + private: + static const uint64_t fnv_prime = 1099511628211u; +}; diff --git a/src/utils/hashing/fnv32.hpp b/src/utils/hashing/fnv32.hpp index 675b26605..2dfc5d75e 100644 --- a/src/utils/hashing/fnv32.hpp +++ b/src/utils/hashing/fnv32.hpp @@ -8,41 +8,21 @@ namespace { #define OFFSET_BASIS32 2166136261u #define FNV_PRIME32 16777619u -uint32_t fnv32(const unsigned char* const data, size_t n) { +__attribute__((unused)) uint32_t fnv32(const std::string &s) { uint32_t hash = OFFSET_BASIS32; - for (size_t i = 0; i < n; ++i) - hash = (hash * FNV_PRIME32) xor (uint32_t) data[i]; + for (size_t i = 0; i < s.size(); ++i) + hash = (hash * FNV_PRIME32) xor (uint32_t) s[i]; return hash; } -template -uint32_t fnv32(const T& data) { - return fnv32(&data, sizeof(data)); -} - -template <> -__attribute__((unused)) uint32_t fnv32(const std::string& data) { - return fnv32((const unsigned char*)data.c_str(), data.size()); -} - -uint32_t fnv1a32(const unsigned char* const data, size_t n) { +__attribute__((unused)) uint32_t fnv1a32(const std::string &s) { uint32_t hash = OFFSET_BASIS32; - for (size_t i = 0; i < n; ++i) - hash = (hash xor (uint32_t) data[i]) * FNV_PRIME32; + for (size_t i = 0; i < s.size(); ++i) + hash = (hash xor (uint32_t) s[i]) * FNV_PRIME32; return hash; } - -template -uint32_t fnv1a32(const T& data) { - return fnv1a32(&data, sizeof(data)); -} - -template <> -__attribute__((unused)) uint32_t fnv1a32(const std::string& data) { - return fnv1a32((const unsigned char*)data.c_str(), data.size()); -} } diff --git a/src/utils/hashing/fnv64.hpp b/src/utils/hashing/fnv64.hpp index 4cbf27b1f..a93eb3147 100644 --- a/src/utils/hashing/fnv64.hpp +++ b/src/utils/hashing/fnv64.hpp @@ -8,41 +8,21 @@ namespace { #define OFFSET_BASIS64 14695981039346656037u #define FNV_PRIME64 1099511628211u -uint64_t fnv64(const unsigned char* const data, size_t n) { +__attribute__((unused)) uint64_t fnv64(const std::string &s) { uint64_t hash = OFFSET_BASIS64; - for (size_t i = 0; i < n; ++i) - hash = (hash * FNV_PRIME64) xor (uint64_t) data[i]; + for (size_t i = 0; i < s.size(); ++i) + hash = (hash * FNV_PRIME64) xor (uint64_t) s[i]; return hash; } -template -uint64_t fnv64(const T& data) { - return fnv64(&data, sizeof(data)); -} - -template <> -__attribute__((unused)) uint64_t fnv64(const std::string& data) { - return fnv64((const unsigned char*)data.c_str(), data.size()); -} - -uint64_t fnv1a64(const unsigned char* const data, size_t n) { +__attribute__((unused)) uint64_t fnv1a64(const std::string &s) { uint64_t hash = OFFSET_BASIS64; - for (size_t i = 0; i < n; ++i) - hash = (hash xor (uint64_t) data[i]) * FNV_PRIME64; + for (size_t i = 0; i < s.size(); ++i) + hash = (hash xor (uint64_t) s[i]) * FNV_PRIME64; return hash; } - -template -uint64_t fnv1a64(const T& data) { - return fnv1a64(&data, sizeof(data)); -} - -template <> -__attribute__((unused)) uint64_t fnv1a64(const std::string& data) { - return fnv1a64((const unsigned char*)data.c_str(), data.size()); -} } diff --git a/tests/benchmark/data_structures/bloom/basic_bloom_filter.cpp b/tests/benchmark/data_structures/bloom/basic_bloom_filter.cpp index 90a38cf22..e5d9d88ce 100644 --- a/tests/benchmark/data_structures/bloom/basic_bloom_filter.cpp +++ b/tests/benchmark/data_structures/bloom/basic_bloom_filter.cpp @@ -38,8 +38,8 @@ int main(int argc, char **argv) { auto elements = utils::random::generate_vector(generator, 1 << 16); - StringHashFunction hash1 = fnv64; - StringHashFunction hash2 = fnv1a64; + StringHashFunction hash1 = fnv64; + StringHashFunction hash2 = fnv1a64; std::vector funcs = {hash1, hash2}; BloomFilter bloom(funcs); diff --git a/tests/benchmark/data_structures/concurrent/bloom_map_concurrent.cpp b/tests/benchmark/data_structures/concurrent/bloom_map_concurrent.cpp index 91d3023e4..0f85a207c 100644 --- a/tests/benchmark/data_structures/concurrent/bloom_map_concurrent.cpp +++ b/tests/benchmark/data_structures/concurrent/bloom_map_concurrent.cpp @@ -129,8 +129,8 @@ int main(int argc, char **argv) { PairGenerator psig(&sg, &ig); PairGenerator pisg(&ig, &sg); - StringHashFunction hash1 = fnv64; - StringHashFunction hash2 = fnv1a64; + StringHashFunction hash1 = fnv64; + StringHashFunction hash2 = fnv1a64; std::vector funcs = {hash1, hash2}; BloomFilter bloom_filter_(funcs); diff --git a/tests/unit/basic_bloom_filter.cpp b/tests/unit/basic_bloom_filter.cpp index 714dcb411..9cf57ad97 100644 --- a/tests/unit/basic_bloom_filter.cpp +++ b/tests/unit/basic_bloom_filter.cpp @@ -9,8 +9,8 @@ using StringHashFunction = std::function; TEST(BloomFilterTest, InsertContains) { - StringHashFunction hash1 = fnv64; - StringHashFunction hash2 = fnv1a64; + StringHashFunction hash1 = fnv64; + StringHashFunction hash2 = fnv1a64; std::vector funcs = {hash1, hash2}; BloomFilter bloom(funcs); diff --git a/tests/unit/query_plan_accumulate_aggregate.cpp b/tests/unit/query_plan_accumulate_aggregate.cpp new file mode 100644 index 000000000..1137c7536 --- /dev/null +++ b/tests/unit/query_plan_accumulate_aggregate.cpp @@ -0,0 +1,382 @@ +// +// Copyright 2017 Memgraph +// Created by Florijan Stamenkovic on 14.03.17. +// + +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +#include "communication/result_stream_faker.hpp" +#include "dbms/dbms.hpp" +#include "query/context.hpp" +#include "query/exceptions.hpp" +#include "query/frontend/interpret/interpret.hpp" +#include "query/frontend/logical/operator.hpp" + +#include "query_plan_common.hpp" + +using namespace query; +using namespace query::plan; + +TEST(QueryPlan, Accumulate) { + // simulate the following two query execution on an empty db + // CREATE ({x:0})-[:T]->({x:0}) + // MATCH (n)--(m) SET n.x = n.x + 1, m.x = m.x + 1 RETURN n.x, m.x + // without accumulation we expected results to be [[1, 1], [2, 2]] + // with accumulation we expect them to be [[2, 2], [2, 2]] + + auto check = [&](bool accumulate) { + Dbms dbms; + auto dba = dbms.active(); + auto prop = dba->property("x"); + + auto v1 = dba->insert_vertex(); + v1.PropsSet(prop, 0); + auto v2 = dba->insert_vertex(); + v2.PropsSet(prop, 0); + dba->insert_edge(v1, v2, dba->edge_type("T")); + dba->advance_command(); + + AstTreeStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n"); + auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", + EdgeAtom::Direction::BOTH, false, "m", false); + + auto one = LITERAL(1); + auto n_p = PROPERTY_LOOKUP("n", prop); + symbol_table[*n_p->expression_] = n.sym_; + auto set_n_p = + std::make_shared(r_m.op_, n_p, ADD(n_p, one)); + auto m_p = PROPERTY_LOOKUP("m", prop); + symbol_table[*m_p->expression_] = r_m.node_sym_; + auto set_m_p = + std::make_shared(set_n_p, m_p, ADD(m_p, one)); + + std::shared_ptr last_op = set_m_p; + if (accumulate) { + last_op = std::make_shared( + last_op, std::vector{n.sym_, r_m.node_sym_}); + } + + auto n_p_ne = NEXPR("n.p", n_p); + symbol_table[*n_p_ne] = symbol_table.CreateSymbol("n_p_ne"); + auto m_p_ne = NEXPR("m.p", m_p); + symbol_table[*m_p_ne] = symbol_table.CreateSymbol("m_p_ne"); + auto produce = MakeProduce(last_op, n_p_ne, m_p_ne); + ResultStreamFaker results = CollectProduce(produce, symbol_table, *dba); + std::vector results_data; + for (const auto &row : results.GetResults()) + for (const auto &column : row) + results_data.emplace_back(column.Value()); + if (accumulate) + EXPECT_THAT(results_data, testing::ElementsAre(2, 2, 2, 2)); + else + EXPECT_THAT(results_data, testing::ElementsAre(1, 1, 2, 2)); + }; + + check(false); + check(true); +} + +TEST(QueryPlan, AccumulateAdvance) { + // we simulate 'CREATE (n) WITH n AS n MATCH (m) RETURN m' + // to get correct results we need to advance the command + + auto check = [&](bool advance) { + Dbms dbms; + auto dba = dbms.active(); + AstTreeStorage storage; + SymbolTable symbol_table; + + auto node = NODE("n"); + auto sym_n = symbol_table.CreateSymbol("n"); + symbol_table[*node->identifier_] = sym_n; + auto create = std::make_shared(node, nullptr); + auto accumulate = std::make_shared( + create, std::vector{sym_n}, advance); + auto match = MakeScanAll(storage, symbol_table, "m", accumulate); + EXPECT_EQ(advance ? 1 : 0, PullAll(match.op_, *dba, symbol_table)); + }; + check(false); + check(true); +} + +std::shared_ptr MakeAggregationProduce( + std::shared_ptr input, SymbolTable &symbol_table, + AstTreeStorage &storage, const std::vector aggr_inputs, + const std::vector aggr_ops, + const std::vector group_by_exprs, + const std::vector remember) { + permanent_assert(aggr_inputs.size() == aggr_ops.size(), + "Provide as many aggr inputs as aggr ops"); + // prepare all the aggregations + std::vector aggregates; + std::vector named_expressions; + + auto aggr_inputs_it = aggr_inputs.begin(); + for (auto aggr_op : aggr_ops) { + // TODO change this from using IDENT to using AGGREGATION + // once AGGREGATION is handled properly in ExpressionEvaluation + auto named_expr = NEXPR("", IDENT("aggregation")); + named_expressions.push_back(named_expr); + symbol_table[*named_expr->expression_] = + symbol_table.CreateSymbol("aggregation"); + symbol_table[*named_expr] = symbol_table.CreateSymbol("named_expression"); + aggregates.emplace_back(*aggr_inputs_it++, aggr_op, + symbol_table[*named_expr->expression_]); + } + + // Produce will also evaluate group_by expressions + // and return them after the aggregations + for (auto group_by_expr : group_by_exprs) { + auto named_expr = NEXPR("", group_by_expr); + named_expressions.push_back(named_expr); + symbol_table[*named_expr] = symbol_table.CreateSymbol("named_expression"); + } + auto aggregation = + std::make_shared(input, aggregates, group_by_exprs, remember); + return std::make_shared(aggregation, named_expressions); +} + +TEST(QueryPlan, AggregateOps) { + Dbms dbms; + auto dba = dbms.active(); + + // setup is several nodes most of which have an int property set + // we will take the sum, avg, min, max and count + // we won't group by anything + auto prop = dba->property("prop"); + dba->insert_vertex().PropsSet(prop, 4); + dba->insert_vertex().PropsSet(prop, 7); + dba->insert_vertex().PropsSet(prop, 12); + // a missing property (null) gets ignored by all aggregations + dba->insert_vertex(); + dba->advance_command(); + + AstTreeStorage storage; + SymbolTable symbol_table; + + // match all nodes and perform aggregations + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_p = PROPERTY_LOOKUP("n", prop); + symbol_table[*n_p->expression_] = n.sym_; + + auto produce = MakeAggregationProduce( + n.op_, symbol_table, storage, std::vector(5, n_p), + {Aggregation::Op::COUNT, Aggregation::Op::MIN, Aggregation::Op::MAX, + Aggregation::Op::SUM, Aggregation::Op::AVG}, + {}, {}); + + // checks + auto results = CollectProduce(produce, symbol_table, *dba).GetResults(); + ASSERT_EQ(results.size(), 1); + ASSERT_EQ(results[0].size(), 5); + // count + ASSERT_EQ(results[0][0].type(), TypedValue::Type::Int); + EXPECT_EQ(results[0][0].Value(), 3); + // min + ASSERT_EQ(results[0][1].type(), TypedValue::Type::Int); + EXPECT_EQ(results[0][1].Value(), 4); + // max + ASSERT_EQ(results[0][2].type(), TypedValue::Type::Int); + EXPECT_EQ(results[0][2].Value(), 12); + // sum + ASSERT_EQ(results[0][3].type(), TypedValue::Type::Int); + EXPECT_EQ(results[0][3].Value(), 23); + // avg + ASSERT_EQ(results[0][4].type(), TypedValue::Type::Double); + EXPECT_FLOAT_EQ(results[0][4].Value(), 23 / 3.0); +} + +TEST(QueryPlan, AggregateGroupByValues) { + // tests that distinct groups are aggregated properly + // for values of all types + // also test the "remember" part of the Aggregation API + // as final results are obtained via a property lookup of + // a remembered node + Dbms dbms; + auto dba = dbms.active(); + + // a vector of TypedValue to be set as property values on vertices + // most of them should result in a distinct group (commented where not) + std::vector group_by_vals; + group_by_vals.emplace_back(4); + group_by_vals.emplace_back(7); + group_by_vals.emplace_back(7.3); + group_by_vals.emplace_back(7.2); + group_by_vals.emplace_back("Johhny"); + group_by_vals.emplace_back("Jane"); + group_by_vals.emplace_back("1"); + group_by_vals.emplace_back(true); + group_by_vals.emplace_back(false); + group_by_vals.emplace_back(std::vector{1}); + group_by_vals.emplace_back(std::vector{1, 2}); + group_by_vals.emplace_back(std::vector{2, 1}); + group_by_vals.emplace_back(TypedValue::Null); + // should NOT result in another group because 7.0 == 7 + group_by_vals.emplace_back(7.0); + // should NOT result in another group + group_by_vals.emplace_back(std::vector{1, 2.0}); + + // generate a lot of vertices and set props on them + auto prop = dba->property("prop"); + for (int i = 0; i < 1000; ++i) + dba->insert_vertex().PropsSet(prop, + group_by_vals[i % group_by_vals.size()]); + dba->advance_command(); + + AstTreeStorage storage; + SymbolTable symbol_table; + + // match all nodes and perform aggregations + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_p = PROPERTY_LOOKUP("n", prop); + symbol_table[*n_p->expression_] = n.sym_; + + auto produce = + MakeAggregationProduce(n.op_, symbol_table, storage, {n_p}, + {Aggregation::Op::COUNT}, {n_p}, {n.sym_}); + + auto results = CollectProduce(produce, symbol_table, *dba).GetResults(); + ASSERT_EQ(results.size(), group_by_vals.size() - 2); + TypedValue::unordered_set result_group_bys; + for (const auto &row : results) { + ASSERT_EQ(2, row.size()); + result_group_bys.insert(row[1]); + } + ASSERT_EQ(result_group_bys.size(), group_by_vals.size() - 2); + EXPECT_TRUE(std::is_permutation( + group_by_vals.begin(), group_by_vals.end() - 2, result_group_bys.begin(), + TypedValue::BoolEqual{})); +} + +TEST(QueryPlan, AggregateMultipleGroupBy) { + // in this test we have 3 different properties that have different values + // for different records and assert that we get the correct combination + // of values in our groups + Dbms dbms; + auto dba = dbms.active(); + + auto prop1 = dba->property("prop1"); + auto prop2 = dba->property("prop2"); + auto prop3 = dba->property("prop3"); + for (int i = 0; i < 2 * 3 * 5; ++i) { + auto v = dba->insert_vertex(); + v.PropsSet(prop1, (bool)(i % 2)); + v.PropsSet(prop2, i % 3); + v.PropsSet(prop3, "value" + std::to_string(i % 5)); + } + dba->advance_command(); + + AstTreeStorage storage; + SymbolTable symbol_table; + + // match all nodes and perform aggregations + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_p1 = PROPERTY_LOOKUP("n", prop1); + auto n_p2 = PROPERTY_LOOKUP("n", prop2); + auto n_p3 = PROPERTY_LOOKUP("n", prop3); + symbol_table[*n_p1->expression_] = n.sym_; + symbol_table[*n_p2->expression_] = n.sym_; + symbol_table[*n_p3->expression_] = n.sym_; + + auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, {n_p1}, + {Aggregation::Op::COUNT}, + {n_p1, n_p2, n_p3}, {n.sym_}); + + auto results = CollectProduce(produce, symbol_table, *dba).GetResults(); + EXPECT_EQ(results.size(), 2 * 3 * 5); +} + +TEST(QueryPlan, AggregateAdvance) { + // we simulate 'CREATE (n {x: 42}) WITH count(n.x) AS c MATCH (m) RETURN m, + // m.x, c' + // to get correct results we need to advance the command in aggregation + // since we only test aggregation, we'll simplify the logical plan and only + // check the count and not all the results + + auto check = [&](bool advance) { + Dbms dbms; + auto dba = dbms.active(); + AstTreeStorage storage; + SymbolTable symbol_table; + + auto node = NODE("n"); + auto sym_n = symbol_table.CreateSymbol("n"); + symbol_table[*node->identifier_] = sym_n; + auto create = std::make_shared(node, nullptr); + + auto aggr_sym = symbol_table.CreateSymbol("aggr_sym"); + auto n_p = PROPERTY_LOOKUP("n", dba->property("x")); + symbol_table[*n_p->expression_] = sym_n; + auto aggregate = std::make_shared( + create, std::vector{Aggregate::Element{ + n_p, Aggregation::Op::COUNT, aggr_sym}}, + std::vector{}, std::vector{}, advance); + auto match = MakeScanAll(storage, symbol_table, "m", aggregate); + EXPECT_EQ(advance ? 1 : 0, PullAll(match.op_, *dba, symbol_table)); + }; +// check(false); + check(true); +} + +TEST(QueryPlan, AggregateTypes) { + // testing exceptions that can get emitted by an aggregation + // does not check all combinations that can result in an exception + // (that logic is defined and tested by TypedValue) + + Dbms dbms; + auto dba = dbms.active(); + + auto p1 = dba->property("p1"); // has only string props + dba->insert_vertex().PropsSet(p1, "string"); + dba->insert_vertex().PropsSet(p1, "str2"); + auto p2 = dba->property("p2"); // combines int and bool + dba->insert_vertex().PropsSet(p2, 42); + dba->insert_vertex().PropsSet(p2, true); + dba->advance_command(); + + AstTreeStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n"); + auto n_p1 = PROPERTY_LOOKUP("n", p1); + symbol_table[*n_p1->expression_] = n.sym_; + auto n_p2 = PROPERTY_LOOKUP("n", p2); + symbol_table[*n_p2->expression_] = n.sym_; + + auto aggregate = [&](Expression *expression, Aggregation::Op aggr_op) { + auto produce = MakeAggregationProduce(n.op_, symbol_table, storage, + {expression}, {aggr_op}, {}, {}); + CollectProduce(produce, symbol_table, *dba).GetResults(); + }; + + // everythin except for COUNT fails on a Vertex + auto n_id = n_p1->expression_; + aggregate(n_id, Aggregation::Op::COUNT); + EXPECT_THROW(aggregate(n_id, Aggregation::Op::MIN), TypedValueException); + EXPECT_THROW(aggregate(n_id, Aggregation::Op::MAX), TypedValueException); + EXPECT_THROW(aggregate(n_id, Aggregation::Op::AVG), TypedValueException); + EXPECT_THROW(aggregate(n_id, Aggregation::Op::SUM), TypedValueException); + + // on strings AVG and SUM fail + aggregate(n_p1, Aggregation::Op::COUNT); + aggregate(n_p1, Aggregation::Op::MIN); + aggregate(n_p1, Aggregation::Op::MAX); + EXPECT_THROW(aggregate(n_p1, Aggregation::Op::AVG), TypedValueException); + EXPECT_THROW(aggregate(n_p1, Aggregation::Op::SUM), TypedValueException); + + // combination of int and bool, everything except count fails + aggregate(n_p2, Aggregation::Op::COUNT); + EXPECT_THROW(aggregate(n_p2, Aggregation::Op::MIN), TypedValueException); + EXPECT_THROW(aggregate(n_p2, Aggregation::Op::MAX), TypedValueException); + EXPECT_THROW(aggregate(n_p2, Aggregation::Op::AVG), TypedValueException); + EXPECT_THROW(aggregate(n_p2, Aggregation::Op::SUM), TypedValueException); +} diff --git a/tests/unit/query_plan_common.hpp b/tests/unit/query_plan_common.hpp new file mode 100644 index 000000000..0d1fbe774 --- /dev/null +++ b/tests/unit/query_plan_common.hpp @@ -0,0 +1,128 @@ +// +// Copyright 2017 Memgraph +// Created by Florijan Stamenkovic on 14.03.17. +// + +#include +#include +#include + +#include "query/frontend/semantic/symbol_table.hpp" +#include "query/frontend/interpret/interpret.hpp" + +#include "query_common.hpp" + +using namespace query; +using namespace query::plan; + +/** + * Helper function that collects all the results from the given + * Produce into a ResultStreamFaker and returns that object. + * + * @param produce + * @param symbol_table + * @param db_accessor + * @return + */ +auto CollectProduce(std::shared_ptr produce, SymbolTable &symbol_table, + GraphDbAccessor &db_accessor) { + ResultStreamFaker stream; + Frame frame(symbol_table.max_position()); + + // top level node in the operator tree is a produce (return) + // so stream out results + + // generate header + std::vector header; + for (auto named_expression : produce->named_expressions()) + header.push_back(named_expression->name_); + stream.Header(header); + + // collect the symbols from the return clause + std::vector symbols; + for (auto named_expression : produce->named_expressions()) + symbols.emplace_back(symbol_table[*named_expression]); + + // stream out results + auto cursor = produce->MakeCursor(db_accessor); + while (cursor->Pull(frame, symbol_table)) { + std::vector values; + for (auto &symbol : symbols) values.emplace_back(frame[symbol]); + stream.Result(values); + } + + stream.Summary({{std::string("type"), TypedValue("r")}}); + + return stream; +} + +int PullAll(std::shared_ptr logical_op, GraphDbAccessor &db, + SymbolTable symbol_table) { + Frame frame(symbol_table.max_position()); + auto cursor = logical_op->MakeCursor(db); + int count = 0; + while (cursor->Pull(frame, symbol_table)) count++; + return count; +} + +template +auto MakeProduce(std::shared_ptr input, + TNamedExpressions... named_expressions) { + return std::make_shared( + input, std::vector{named_expressions...}); +} + +struct ScanAllTuple { + NodeAtom *node_; + std::shared_ptr op_; + Symbol sym_; +}; + +/** + * Creates and returns a tuple of stuff for a scan-all starting + * from the node with the given name. + * + * Returns (node_atom, scan_all_logical_op, symbol). + */ +ScanAllTuple MakeScanAll(AstTreeStorage &storage, SymbolTable &symbol_table, + const std::string &identifier, + std::shared_ptr input = {nullptr}) { + auto node = NODE(identifier); + auto logical_op = std::make_shared(node, input); + auto symbol = symbol_table.CreateSymbol(identifier); + symbol_table[*node->identifier_] = symbol; + // return std::make_tuple(node, logical_op, symbol); + return ScanAllTuple{node, logical_op, symbol}; +} + +struct ExpandTuple { + EdgeAtom *edge_; + Symbol edge_sym_; + NodeAtom *node_; + Symbol node_sym_; + std::shared_ptr op_; +}; + +ExpandTuple MakeExpand(AstTreeStorage &storage, SymbolTable &symbol_table, + std::shared_ptr input, + Symbol input_symbol, const std::string &edge_identifier, + EdgeAtom::Direction direction, bool edge_cycle, + const std::string &node_identifier, bool node_cycle) { + auto edge = EDGE(edge_identifier, direction); + auto edge_sym = symbol_table.CreateSymbol(edge_identifier); + symbol_table[*edge->identifier_] = edge_sym; + + auto node = NODE(node_identifier); + auto node_sym = symbol_table.CreateSymbol(node_identifier); + symbol_table[*node->identifier_] = node_sym; + + auto op = std::make_shared(node, edge, input, input_symbol, + node_cycle, edge_cycle); + + return ExpandTuple{edge, edge_sym, node, node_sym, op}; +} + +template +auto CountIterable(TIterable iterable) { + return std::distance(iterable.begin(), iterable.end()); +} diff --git a/tests/unit/interpreter.cpp b/tests/unit/query_plan_create_set_remove_delete.cpp similarity index 50% rename from tests/unit/interpreter.cpp rename to tests/unit/query_plan_create_set_remove_delete.cpp index 9fc5a54b8..af49d66ea 100644 --- a/tests/unit/interpreter.cpp +++ b/tests/unit/query_plan_create_set_remove_delete.cpp @@ -15,300 +15,14 @@ #include "query/context.hpp" #include "query/exceptions.hpp" #include "query/frontend/interpret/interpret.hpp" -#include "query/frontend/logical/planner.hpp" +#include "query/frontend/logical/operator.hpp" -#include "query_common.hpp" +#include "query_plan_common.hpp" using namespace query; using namespace query::plan; -/** - * Helper function that collects all the results from the given - * Produce into a ResultStreamFaker and returns that object. - * - * @param produce - * @param symbol_table - * @param db_accessor - * @return - */ -auto CollectProduce(std::shared_ptr produce, SymbolTable &symbol_table, - GraphDbAccessor &db_accessor) { - ResultStreamFaker stream; - Frame frame(symbol_table.max_position()); - - // top level node in the operator tree is a produce (return) - // so stream out results - - // generate header - std::vector header; - for (auto named_expression : produce->named_expressions()) - header.push_back(named_expression->name_); - stream.Header(header); - - // collect the symbols from the return clause - std::vector symbols; - for (auto named_expression : produce->named_expressions()) - symbols.emplace_back(symbol_table[*named_expression]); - - // stream out results - auto cursor = produce->MakeCursor(db_accessor); - while (cursor->Pull(frame, symbol_table)) { - std::vector values; - for (auto &symbol : symbols) values.emplace_back(frame[symbol]); - stream.Result(values); - } - - stream.Summary({{std::string("type"), TypedValue("r")}}); - - return stream; -} - -int PullAll(std::shared_ptr logical_op, GraphDbAccessor &db, - SymbolTable symbol_table) { - Frame frame(symbol_table.max_position()); - auto cursor = logical_op->MakeCursor(db); - int count = 0; - while (cursor->Pull(frame, symbol_table)) count++; - return count; -} - -template -auto MakeProduce(std::shared_ptr input, - TNamedExpressions... named_expressions) { - return std::make_shared( - input, std::vector{named_expressions...}); -} - -struct ScanAllTuple { - NodeAtom *node_; - std::shared_ptr op_; - Symbol sym_; -}; - -/** - * Creates and returns a tuple of stuff for a scan-all starting - * from the node with the given name. - * - * Returns (node_atom, scan_all_logical_op, symbol). - */ -ScanAllTuple MakeScanAll(AstTreeStorage &storage, SymbolTable &symbol_table, - const std::string &identifier, - std::shared_ptr input = {nullptr}) { - auto node = NODE(identifier); - auto logical_op = std::make_shared(node, input); - auto symbol = symbol_table.CreateSymbol(identifier); - symbol_table[*node->identifier_] = symbol; - // return std::make_tuple(node, logical_op, symbol); - return ScanAllTuple{node, logical_op, symbol}; -} - -struct ExpandTuple { - EdgeAtom *edge_; - Symbol edge_sym_; - NodeAtom *node_; - Symbol node_sym_; - std::shared_ptr op_; -}; - -ExpandTuple MakeExpand(AstTreeStorage &storage, SymbolTable &symbol_table, - std::shared_ptr input, - Symbol input_symbol, const std::string &edge_identifier, - EdgeAtom::Direction direction, bool edge_cycle, - const std::string &node_identifier, bool node_cycle) { - auto edge = EDGE(edge_identifier, direction); - auto edge_sym = symbol_table.CreateSymbol(edge_identifier); - symbol_table[*edge->identifier_] = edge_sym; - - auto node = NODE(node_identifier); - auto node_sym = symbol_table.CreateSymbol(node_identifier); - symbol_table[*node->identifier_] = node_sym; - - auto op = std::make_shared(node, edge, input, input_symbol, - node_cycle, edge_cycle); - - return ExpandTuple{edge, edge_sym, node, node_sym, op}; -} - -template -auto CountIterable(TIterable iterable) { - return std::distance(iterable.begin(), iterable.end()); -} - -/* - * Actual tests start here. - */ - -TEST(Interpreter, MatchReturn) { - Dbms dbms; - auto dba = dbms.active(); - - // add a few nodes to the database - dba->insert_vertex(); - dba->insert_vertex(); - dba->advance_command(); - - AstTreeStorage storage; - SymbolTable symbol_table; - - auto scan_all = MakeScanAll(storage, symbol_table, "n"); - auto output = NEXPR("n", IDENT("n")); - auto produce = MakeProduce(scan_all.op_, output); - symbol_table[*output->expression_] = scan_all.sym_; - symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1"); - - ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba); - EXPECT_EQ(result.GetResults().size(), 2); -} - -TEST(Interpreter, MatchReturnCartesian) { - Dbms dbms; - auto dba = dbms.active(); - - dba->insert_vertex().add_label(dba->label("l1")); - dba->insert_vertex().add_label(dba->label("l2")); - dba->advance_command(); - - AstTreeStorage storage; - SymbolTable symbol_table; - - auto n = MakeScanAll(storage, symbol_table, "n"); - auto m = MakeScanAll(storage, symbol_table, "m", n.op_); - auto return_n = NEXPR("n", IDENT("n")); - symbol_table[*return_n->expression_] = n.sym_; - symbol_table[*return_n] = symbol_table.CreateSymbol("named_expression_1"); - auto return_m = NEXPR("m", IDENT("m")); - symbol_table[*return_m->expression_] = m.sym_; - symbol_table[*return_m] = symbol_table.CreateSymbol("named_expression_2"); - auto produce = MakeProduce(m.op_, return_n, return_m); - - ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba); - auto result_data = result.GetResults(); - EXPECT_EQ(result_data.size(), 4); - // ensure the result ordering is OK: - // "n" from the results is the same for the first two rows, while "m" isn't - EXPECT_EQ(result_data[0][0].Value(), - result_data[1][0].Value()); - EXPECT_NE(result_data[0][1].Value(), - result_data[1][1].Value()); -} - -TEST(Interpreter, StandaloneReturn) { - Dbms dbms; - auto dba = dbms.active(); - - // add a few nodes to the database - dba->insert_vertex(); - dba->insert_vertex(); - dba->advance_command(); - - AstTreeStorage storage; - SymbolTable symbol_table; - - auto output = NEXPR("n", LITERAL(42)); - auto produce = MakeProduce(std::shared_ptr(nullptr), output); - symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1"); - - ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba); - EXPECT_EQ(result.GetResults().size(), 1); - EXPECT_EQ(result.GetResults()[0].size(), 1); - EXPECT_EQ(result.GetResults()[0][0].Value(), 42); -} - -TEST(Interpreter, NodeFilterLabelsAndProperties) { - Dbms dbms; - auto dba = dbms.active(); - - // add a few nodes to the database - GraphDbTypes::Label label = dba->label("Label"); - GraphDbTypes::Property property = dba->property("Property"); - auto v1 = dba->insert_vertex(); - auto v2 = dba->insert_vertex(); - auto v3 = dba->insert_vertex(); - auto v4 = dba->insert_vertex(); - auto v5 = dba->insert_vertex(); - dba->insert_vertex(); - // test all combination of (label | no_label) * (no_prop | wrong_prop | - // right_prop) - // only v1 will have the right labels - v1.add_label(label); - v2.add_label(label); - v3.add_label(label); - v1.PropsSet(property, 42); - v2.PropsSet(property, 1); - v4.PropsSet(property, 42); - v5.PropsSet(property, 1); - dba->advance_command(); - - AstTreeStorage storage; - SymbolTable symbol_table; - - // make a scan all - auto n = MakeScanAll(storage, symbol_table, "n"); - n.node_->labels_.emplace_back(label); - n.node_->properties_[property] = LITERAL(42); - - // node filtering - auto node_filter = std::make_shared(n.op_, n.sym_, n.node_); - - // make a named expression and a produce - auto output = NEXPR("x", IDENT("n")); - symbol_table[*output->expression_] = n.sym_; - symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1"); - auto produce = MakeProduce(node_filter, output); - - ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba); - EXPECT_EQ(result.GetResults().size(), 1); -} - -TEST(Interpreter, NodeFilterMultipleLabels) { - Dbms dbms; - auto dba = dbms.active(); - - // add a few nodes to the database - GraphDbTypes::Label label1 = dba->label("label1"); - GraphDbTypes::Label label2 = dba->label("label2"); - GraphDbTypes::Label label3 = dba->label("label3"); - // the test will look for nodes that have label1 and label2 - dba->insert_vertex(); // NOT accepted - dba->insert_vertex().add_label(label1); // NOT accepted - dba->insert_vertex().add_label(label2); // NOT accepted - dba->insert_vertex().add_label(label3); // NOT accepted - auto v1 = dba->insert_vertex(); // YES accepted - v1.add_label(label1); - v1.add_label(label2); - auto v2 = dba->insert_vertex(); // NOT accepted - v2.add_label(label1); - v2.add_label(label3); - auto v3 = dba->insert_vertex(); // YES accepted - v3.add_label(label1); - v3.add_label(label2); - v3.add_label(label3); - dba->advance_command(); - - AstTreeStorage storage; - SymbolTable symbol_table; - - // make a scan all - auto n = MakeScanAll(storage, symbol_table, "n"); - n.node_->labels_.emplace_back(label1); - n.node_->labels_.emplace_back(label2); - - // node filtering - auto node_filter = std::make_shared(n.op_, n.sym_, n.node_); - - // make a named expression and a produce - auto output = NEXPR("n", IDENT("n")); - auto produce = MakeProduce(node_filter, output); - - // fill up the symbol table - symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1"); - symbol_table[*output->expression_] = n.sym_; - - ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba); - EXPECT_EQ(result.GetResults().size(), 2); -} - -TEST(Interpreter, CreateNodeWithAttributes) { +TEST(QueryPlan, CreateNodeWithAttributes) { Dbms dbms; auto dba = dbms.active(); @@ -341,7 +55,7 @@ TEST(Interpreter, CreateNodeWithAttributes) { EXPECT_EQ(vertex_count, 1); } -TEST(Interpreter, CreateReturn) { +TEST(QueryPlan, CreateReturn) { // test CREATE (n:Person {age: 42}) RETURN n, n.age Dbms dbms; auto dba = dbms.active(); @@ -384,7 +98,7 @@ TEST(Interpreter, CreateReturn) { EXPECT_EQ(1, CountIterable(dba->vertices())); } -TEST(Interpreter, CreateExpand) { +TEST(QueryPlan, CreateExpand) { Dbms dbms; auto dba = dbms.active(); @@ -457,7 +171,7 @@ TEST(Interpreter, CreateExpand) { } } -TEST(Interpreter, MatchCreateNode) { +TEST(QueryPlan, MatchCreateNode) { Dbms dbms; auto dba = dbms.active(); @@ -484,7 +198,7 @@ TEST(Interpreter, MatchCreateNode) { EXPECT_EQ(CountIterable(dba->vertices()), 6); } -TEST(Interpreter, MatchCreateExpand) { +TEST(QueryPlan, MatchCreateExpand) { Dbms dbms; auto dba = dbms.active(); @@ -535,223 +249,7 @@ TEST(Interpreter, MatchCreateExpand) { test_create_path(true, 0, 6); } -TEST(Interpreter, Expand) { - Dbms dbms; - auto dba = dbms.active(); - - // make a V-graph (v3)<-[r2]-(v1)-[r1]->(v2) - auto v1 = dba->insert_vertex(); - v1.add_label((GraphDbTypes::Label)1); - auto v2 = dba->insert_vertex(); - v2.add_label((GraphDbTypes::Label)2); - auto v3 = dba->insert_vertex(); - v3.add_label((GraphDbTypes::Label)3); - auto edge_type = dba->edge_type("Edge"); - dba->insert_edge(v1, v2, edge_type); - dba->insert_edge(v1, v3, edge_type); - dba->advance_command(); - - AstTreeStorage storage; - SymbolTable symbol_table; - - auto test_expand = [&](EdgeAtom::Direction direction, - int expected_result_count) { - auto n = MakeScanAll(storage, symbol_table, "n"); - auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", direction, - false, "m", false); - - // make a named expression and a produce - auto output = NEXPR("m", IDENT("m")); - symbol_table[*output->expression_] = r_m.node_sym_; - symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1"); - auto produce = MakeProduce(r_m.op_, output); - - ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba); - EXPECT_EQ(result.GetResults().size(), expected_result_count); - }; - - test_expand(EdgeAtom::Direction::RIGHT, 2); - test_expand(EdgeAtom::Direction::LEFT, 2); - test_expand(EdgeAtom::Direction::BOTH, 4); -} - -TEST(Interpreter, ExpandNodeCycle) { - Dbms dbms; - auto dba = dbms.active(); - - // make a graph (v1)->(v2) that - // has a recursive edge (v1)->(v1) - auto v1 = dba->insert_vertex(); - auto v2 = dba->insert_vertex(); - auto edge_type = dba->edge_type("Edge"); - dba->insert_edge(v1, v1, edge_type); - dba->insert_edge(v1, v2, edge_type); - dba->advance_command(); - - AstTreeStorage storage; - SymbolTable symbol_table; - - auto test_cycle = [&](bool with_cycle, int expected_result_count) { - auto n = MakeScanAll(storage, symbol_table, "n"); - auto r_n = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", - EdgeAtom::Direction::RIGHT, false, "n", with_cycle); - if (with_cycle) - symbol_table[*r_n.node_->identifier_] = - symbol_table[*n.node_->identifier_]; - - // make a named expression and a produce - auto output = NEXPR("n", IDENT("n")); - symbol_table[*output->expression_] = n.sym_; - symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1"); - auto produce = MakeProduce(r_n.op_, output); - - ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba); - EXPECT_EQ(result.GetResults().size(), expected_result_count); - }; - - test_cycle(true, 1); - test_cycle(false, 2); -} - -TEST(Interpreter, ExpandEdgeCycle) { - Dbms dbms; - auto dba = dbms.active(); - - // make a V-graph (v3)<-[r2]-(v1)-[r1]->(v2) - auto v1 = dba->insert_vertex(); - v1.add_label((GraphDbTypes::Label)1); - auto v2 = dba->insert_vertex(); - v2.add_label((GraphDbTypes::Label)2); - auto v3 = dba->insert_vertex(); - v3.add_label((GraphDbTypes::Label)3); - auto edge_type = dba->edge_type("Edge"); - dba->insert_edge(v1, v2, edge_type); - dba->insert_edge(v1, v3, edge_type); - dba->advance_command(); - - AstTreeStorage storage; - SymbolTable symbol_table; - - auto test_cycle = [&](bool with_cycle, int expected_result_count) { - auto i = MakeScanAll(storage, symbol_table, "i"); - auto r_j = MakeExpand(storage, symbol_table, i.op_, i.sym_, "r", - EdgeAtom::Direction::BOTH, false, "j", false); - auto r_k = MakeExpand(storage, symbol_table, r_j.op_, r_j.node_sym_, "r", - EdgeAtom::Direction::BOTH, with_cycle, "k", false); - if (with_cycle) - symbol_table[*r_k.edge_->identifier_] = - symbol_table[*r_j.edge_->identifier_]; - - // make a named expression and a produce - auto output = NEXPR("r", IDENT("r")); - symbol_table[*output->expression_] = r_j.edge_sym_; - symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1"); - auto produce = MakeProduce(r_k.op_, output); - - ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba); - EXPECT_EQ(result.GetResults().size(), expected_result_count); - - }; - - test_cycle(true, 4); - test_cycle(false, 6); -} - -TEST(Interpreter, EdgeFilter) { - Dbms dbms; - auto dba = dbms.active(); - - // make an N-star expanding from (v1) - // where only one edge will qualify - // and there are all combinations of - // (edge_type yes|no) * (property yes|absent|no) - std::vector edge_types; - for (int j = 0; j < 2; ++j) - edge_types.push_back(dba->edge_type("et" + std::to_string(j))); - std::vector vertices; - for (int i = 0; i < 7; ++i) vertices.push_back(dba->insert_vertex()); - GraphDbTypes::Property prop = dba->property("prop"); - std::vector edges; - for (int i = 0; i < 6; ++i) { - edges.push_back( - dba->insert_edge(vertices[0], vertices[i + 1], edge_types[i % 2])); - switch (i % 3) { - case 0: - edges.back().PropsSet(prop, 42); - break; - case 1: - edges.back().PropsSet(prop, 100); - break; - default: - break; - } - } - dba->advance_command(); - - AstTreeStorage storage; - SymbolTable symbol_table; - - // define an operator tree for query - // MATCH (n)-[r]->(m) RETURN m - - auto n = MakeScanAll(storage, symbol_table, "n"); - auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", - EdgeAtom::Direction::RIGHT, false, "m", false); - r_m.edge_->edge_types_.push_back(edge_types[0]); - r_m.edge_->properties_[prop] = LITERAL(42); - auto edge_filter = - std::make_shared(r_m.op_, r_m.edge_sym_, r_m.edge_); - - // make a named expression and a produce - auto output = NEXPR("m", IDENT("m")); - symbol_table[*output->expression_] = r_m.node_sym_; - symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1"); - auto produce = MakeProduce(edge_filter, output); - - ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba); - EXPECT_EQ(result.GetResults().size(), 1); -} - -TEST(Interpreter, EdgeFilterMultipleTypes) { - Dbms dbms; - auto dba = dbms.active(); - - auto v1 = dba->insert_vertex(); - auto v2 = dba->insert_vertex(); - auto type_1 = dba->edge_type("type_1"); - auto type_2 = dba->edge_type("type_2"); - auto type_3 = dba->edge_type("type_3"); - dba->insert_edge(v1, v2, type_1); - dba->insert_edge(v1, v2, type_2); - dba->insert_edge(v1, v2, type_3); - dba->advance_command(); - - AstTreeStorage storage; - SymbolTable symbol_table; - - // make a scan all - auto n = MakeScanAll(storage, symbol_table, "n"); - auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", - EdgeAtom::Direction::RIGHT, false, "m", false); - // add a property filter - auto edge_filter = - std::make_shared(r_m.op_, r_m.edge_sym_, r_m.edge_); - r_m.edge_->edge_types_.push_back(type_1); - r_m.edge_->edge_types_.push_back(type_2); - - // make a named expression and a produce - auto output = NEXPR("m", IDENT("m")); - auto produce = MakeProduce(edge_filter, output); - - // fill up the symbol table - symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1"); - symbol_table[*output->expression_] = r_m.node_sym_; - - ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba); - EXPECT_EQ(result.GetResults().size(), 2); -} - -TEST(Interpreter, Delete) { +TEST(QueryPlan, Delete) { Dbms dbms; auto dba = dbms.active(); @@ -826,7 +324,7 @@ TEST(Interpreter, Delete) { } } -TEST(Interpreter, DeleteReturn) { +TEST(QueryPlan, DeleteReturn) { Dbms dbms; auto dba = dbms.active(); @@ -864,36 +362,7 @@ TEST(Interpreter, DeleteReturn) { EXPECT_EQ(0, CountIterable(dba->vertices())); } -TEST(Interpreter, Filter) { - Dbms dbms; - auto dba = dbms.active(); - - // add a 6 nodes with property 'prop', 2 have true as value - GraphDbTypes::Property property = dba->property("Property"); - for (int i = 0; i < 6; ++i) - dba->insert_vertex().PropsSet(property, i % 3 == 0); - dba->insert_vertex(); // prop not set, gives NULL - dba->advance_command(); - - AstTreeStorage storage; - SymbolTable symbol_table; - - auto n = MakeScanAll(storage, symbol_table, "n"); - auto e = - storage.Create(storage.Create("n"), property); - symbol_table[*e->expression_] = n.sym_; - auto f = std::make_shared(n.op_, e); - - auto output = - storage.Create("x", storage.Create("n")); - symbol_table[*output->expression_] = n.sym_; - symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1"); - auto produce = MakeProduce(f, output); - - EXPECT_EQ(CollectProduce(produce, symbol_table, *dba).GetResults().size(), 2); -} - -TEST(Interpreter, SetProperty) { +TEST(QueryPlan, SetProperty) { Dbms dbms; auto dba = dbms.active(); @@ -943,7 +412,7 @@ TEST(Interpreter, SetProperty) { } } -TEST(Interpreter, SetProperties) { +TEST(QueryPlan, SetProperties) { auto test_set_properties = [](bool update) { Dbms dbms; auto dba = dbms.active(); @@ -1013,7 +482,7 @@ TEST(Interpreter, SetProperties) { test_set_properties(false); } -TEST(Interpreter, SetLabels) { +TEST(QueryPlan, SetLabels) { Dbms dbms; auto dba = dbms.active(); @@ -1040,7 +509,7 @@ TEST(Interpreter, SetLabels) { } } -TEST(Interpreter, RemoveProperty) { +TEST(QueryPlan, RemoveProperty) { Dbms dbms; auto dba = dbms.active(); @@ -1092,7 +561,7 @@ TEST(Interpreter, RemoveProperty) { } } -TEST(Interpreter, RemoveLabels) { +TEST(QueryPlan, RemoveLabels) { Dbms dbms; auto dba = dbms.active(); @@ -1124,7 +593,7 @@ TEST(Interpreter, RemoveLabels) { } } -TEST(Interpreter, NodeFilterSet) { +TEST(QueryPlan, NodeFilterSet) { Dbms dbms; auto dba = dbms.active(); // Create a graph such that (v1 {prop: 42}) is connected to v2 and v3. @@ -1162,7 +631,7 @@ TEST(Interpreter, NodeFilterSet) { EXPECT_TRUE(prop_eq.Value()); } -TEST(Interpreter, FilterRemove) { +TEST(QueryPlan, FilterRemove) { Dbms dbms; auto dba = dbms.active(); // Create a graph such that (v1 {prop: 42}) is connected to v2 and v3. @@ -1198,7 +667,7 @@ TEST(Interpreter, FilterRemove) { EXPECT_EQ(v1.PropsAt(prop).type(), PropertyValue::Type::Null); } -TEST(Interpreter, SetRemove) { +TEST(QueryPlan, SetRemove) { Dbms dbms; auto dba = dbms.active(); auto v = dba->insert_vertex(); @@ -1222,132 +691,3 @@ TEST(Interpreter, SetRemove) { EXPECT_FALSE(v.has_label(label1)); EXPECT_FALSE(v.has_label(label2)); } - -TEST(Interpreter, ExpandUniquenessFilter) { - Dbms dbms; - auto dba = dbms.active(); - - // make a graph that has (v1)->(v2) and a recursive edge (v1)->(v1) - auto v1 = dba->insert_vertex(); - auto v2 = dba->insert_vertex(); - auto edge_type = dba->edge_type("edge_type"); - dba->insert_edge(v1, v2, edge_type); - dba->insert_edge(v1, v1, edge_type); - dba->advance_command(); - - auto check_expand_results = [&](bool vertex_uniqueness, - bool edge_uniqueness) { - AstTreeStorage storage; - SymbolTable symbol_table; - - auto n1 = MakeScanAll(storage, symbol_table, "n1"); - auto r1_n2 = MakeExpand(storage, symbol_table, n1.op_, n1.sym_, "r1", - EdgeAtom::Direction::RIGHT, false, "n2", false); - std::shared_ptr last_op = r1_n2.op_; - if (vertex_uniqueness) - last_op = std::make_shared>( - last_op, r1_n2.node_sym_, std::vector{n1.sym_}); - auto r2_n3 = - MakeExpand(storage, symbol_table, last_op, r1_n2.node_sym_, "r2", - EdgeAtom::Direction::RIGHT, false, "n3", false); - last_op = r2_n3.op_; - if (edge_uniqueness) - last_op = std::make_shared>( - last_op, r2_n3.edge_sym_, std::vector{r1_n2.edge_sym_}); - if (vertex_uniqueness) - last_op = std::make_shared>( - last_op, r2_n3.node_sym_, - std::vector{n1.sym_, r1_n2.node_sym_}); - - return PullAll(last_op, *dba, symbol_table); - }; - - EXPECT_EQ(2, check_expand_results(false, false)); - EXPECT_EQ(0, check_expand_results(true, false)); - EXPECT_EQ(1, check_expand_results(false, true)); -} - -TEST(Interpreter, Accumulate) { - // simulate the following two query execution on an empty db - // CREATE ({x:0})-[:T]->({x:0}) - // MATCH (n)--(m) SET n.x = n.x + 1, m.x = m.x + 1 RETURN n.x, m.x - // without accumulation we expected results to be [[1, 1], [2, 2]] - // with accumulation we expect them to be [[2, 2], [2, 2]] - - auto check = [&](bool accumulate) { - Dbms dbms; - auto dba = dbms.active(); - auto prop = dba->property("x"); - - auto v1 = dba->insert_vertex(); - v1.PropsSet(prop, 0); - auto v2 = dba->insert_vertex(); - v2.PropsSet(prop, 0); - dba->insert_edge(v1, v2, dba->edge_type("T")); - dba->advance_command(); - - AstTreeStorage storage; - SymbolTable symbol_table; - - auto n = MakeScanAll(storage, symbol_table, "n"); - auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", - EdgeAtom::Direction::BOTH, false, "m", false); - - auto one = LITERAL(1); - auto n_p = PROPERTY_LOOKUP("n", prop); - symbol_table[*n_p->expression_] = n.sym_; - auto set_n_p = - std::make_shared(r_m.op_, n_p, ADD(n_p, one)); - auto m_p = PROPERTY_LOOKUP("m", prop); - symbol_table[*m_p->expression_] = r_m.node_sym_; - auto set_m_p = - std::make_shared(set_n_p, m_p, ADD(m_p, one)); - - std::shared_ptr last_op = set_m_p; - if (accumulate) { - last_op = std::make_shared( - last_op, std::vector{n.sym_, r_m.node_sym_}); - } - - auto n_p_ne = NEXPR("n.p", n_p); - symbol_table[*n_p_ne] = symbol_table.CreateSymbol("n_p_ne"); - auto m_p_ne = NEXPR("m.p", m_p); - symbol_table[*m_p_ne] = symbol_table.CreateSymbol("m_p_ne"); - auto produce = MakeProduce(last_op, n_p_ne, m_p_ne); - ResultStreamFaker results = CollectProduce(produce, symbol_table, *dba); - std::vector results_data; - for (const auto &row : results.GetResults()) - for (const auto &column : row) - results_data.emplace_back(column.Value()); - if (accumulate) - EXPECT_THAT(results_data, testing::ElementsAre(2, 2, 2, 2)); - else - EXPECT_THAT(results_data, testing::ElementsAre(1, 1, 2, 2)); - }; - - check(false); - check(true); -} - -TEST(Interpreter, AccumulateAdvance) { - // we simulate 'CREATE (n) WITH n AS n MATCH (m) RETURN m' - // to get correct results we need to advance the command - - auto check = [&](bool advance) { - Dbms dbms; - auto dba = dbms.active(); - AstTreeStorage storage; - SymbolTable symbol_table; - - auto node = NODE("n"); - auto sym_n = symbol_table.CreateSymbol("n"); - symbol_table[*node->identifier_] = sym_n; - auto create = std::make_shared(node, nullptr); - auto accumulate = std::make_shared( - create, std::vector{sym_n}, advance); - auto match = MakeScanAll(storage, symbol_table, "m", accumulate); - EXPECT_EQ(advance ? 1 : 0, PullAll(match.op_, *dba, symbol_table)); - }; - check(false); - check(true); -} diff --git a/tests/unit/query_plan_match_filter_return.cpp b/tests/unit/query_plan_match_filter_return.cpp new file mode 100644 index 000000000..8510bab89 --- /dev/null +++ b/tests/unit/query_plan_match_filter_return.cpp @@ -0,0 +1,482 @@ +// +// Copyright 2017 Memgraph +// Created by Florijan Stamenkovic on 14.03.17. +// + +#include +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +#include "communication/result_stream_faker.hpp" +#include "dbms/dbms.hpp" +#include "query/context.hpp" +#include "query/exceptions.hpp" +#include "query/frontend/interpret/interpret.hpp" +#include "query/frontend/logical/operator.hpp" + +#include "query_plan_common.hpp" + +using namespace query; +using namespace query::plan; + +TEST(QueryPlan, MatchReturn) { + Dbms dbms; + auto dba = dbms.active(); + + // add a few nodes to the database + dba->insert_vertex(); + dba->insert_vertex(); + dba->advance_command(); + + AstTreeStorage storage; + SymbolTable symbol_table; + + auto scan_all = MakeScanAll(storage, symbol_table, "n"); + auto output = NEXPR("n", IDENT("n")); + auto produce = MakeProduce(scan_all.op_, output); + symbol_table[*output->expression_] = scan_all.sym_; + symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1"); + + ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba); + EXPECT_EQ(result.GetResults().size(), 2); +} + +TEST(QueryPlan, MatchReturnCartesian) { + Dbms dbms; + auto dba = dbms.active(); + + dba->insert_vertex().add_label(dba->label("l1")); + dba->insert_vertex().add_label(dba->label("l2")); + dba->advance_command(); + + AstTreeStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n"); + auto m = MakeScanAll(storage, symbol_table, "m", n.op_); + auto return_n = NEXPR("n", IDENT("n")); + symbol_table[*return_n->expression_] = n.sym_; + symbol_table[*return_n] = symbol_table.CreateSymbol("named_expression_1"); + auto return_m = NEXPR("m", IDENT("m")); + symbol_table[*return_m->expression_] = m.sym_; + symbol_table[*return_m] = symbol_table.CreateSymbol("named_expression_2"); + auto produce = MakeProduce(m.op_, return_n, return_m); + + ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba); + auto result_data = result.GetResults(); + EXPECT_EQ(result_data.size(), 4); + // ensure the result ordering is OK: + // "n" from the results is the same for the first two rows, while "m" isn't + EXPECT_EQ(result_data[0][0].Value(), + result_data[1][0].Value()); + EXPECT_NE(result_data[0][1].Value(), + result_data[1][1].Value()); +} + +TEST(QueryPlan, StandaloneReturn) { + Dbms dbms; + auto dba = dbms.active(); + + // add a few nodes to the database + dba->insert_vertex(); + dba->insert_vertex(); + dba->advance_command(); + + AstTreeStorage storage; + SymbolTable symbol_table; + + auto output = NEXPR("n", LITERAL(42)); + auto produce = MakeProduce(std::shared_ptr(nullptr), output); + symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1"); + + ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba); + EXPECT_EQ(result.GetResults().size(), 1); + EXPECT_EQ(result.GetResults()[0].size(), 1); + EXPECT_EQ(result.GetResults()[0][0].Value(), 42); +} + +TEST(QueryPlan, NodeFilterLabelsAndProperties) { + Dbms dbms; + auto dba = dbms.active(); + + // add a few nodes to the database + GraphDbTypes::Label label = dba->label("Label"); + GraphDbTypes::Property property = dba->property("Property"); + auto v1 = dba->insert_vertex(); + auto v2 = dba->insert_vertex(); + auto v3 = dba->insert_vertex(); + auto v4 = dba->insert_vertex(); + auto v5 = dba->insert_vertex(); + dba->insert_vertex(); + // test all combination of (label | no_label) * (no_prop | wrong_prop | + // right_prop) + // only v1 will have the right labels + v1.add_label(label); + v2.add_label(label); + v3.add_label(label); + v1.PropsSet(property, 42); + v2.PropsSet(property, 1); + v4.PropsSet(property, 42); + v5.PropsSet(property, 1); + dba->advance_command(); + + AstTreeStorage storage; + SymbolTable symbol_table; + + // make a scan all + auto n = MakeScanAll(storage, symbol_table, "n"); + n.node_->labels_.emplace_back(label); + n.node_->properties_[property] = LITERAL(42); + + // node filtering + auto node_filter = std::make_shared(n.op_, n.sym_, n.node_); + + // make a named expression and a produce + auto output = NEXPR("x", IDENT("n")); + symbol_table[*output->expression_] = n.sym_; + symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1"); + auto produce = MakeProduce(node_filter, output); + + ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba); + EXPECT_EQ(result.GetResults().size(), 1); +} + +TEST(QueryPlan, NodeFilterMultipleLabels) { + Dbms dbms; + auto dba = dbms.active(); + + // add a few nodes to the database + GraphDbTypes::Label label1 = dba->label("label1"); + GraphDbTypes::Label label2 = dba->label("label2"); + GraphDbTypes::Label label3 = dba->label("label3"); + // the test will look for nodes that have label1 and label2 + dba->insert_vertex(); // NOT accepted + dba->insert_vertex().add_label(label1); // NOT accepted + dba->insert_vertex().add_label(label2); // NOT accepted + dba->insert_vertex().add_label(label3); // NOT accepted + auto v1 = dba->insert_vertex(); // YES accepted + v1.add_label(label1); + v1.add_label(label2); + auto v2 = dba->insert_vertex(); // NOT accepted + v2.add_label(label1); + v2.add_label(label3); + auto v3 = dba->insert_vertex(); // YES accepted + v3.add_label(label1); + v3.add_label(label2); + v3.add_label(label3); + dba->advance_command(); + + AstTreeStorage storage; + SymbolTable symbol_table; + + // make a scan all + auto n = MakeScanAll(storage, symbol_table, "n"); + n.node_->labels_.emplace_back(label1); + n.node_->labels_.emplace_back(label2); + + // node filtering + auto node_filter = std::make_shared(n.op_, n.sym_, n.node_); + + // make a named expression and a produce + auto output = NEXPR("n", IDENT("n")); + auto produce = MakeProduce(node_filter, output); + + // fill up the symbol table + symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1"); + symbol_table[*output->expression_] = n.sym_; + + ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba); + EXPECT_EQ(result.GetResults().size(), 2); +} + +TEST(QueryPlan, Expand) { + Dbms dbms; + auto dba = dbms.active(); + + // make a V-graph (v3)<-[r2]-(v1)-[r1]->(v2) + auto v1 = dba->insert_vertex(); + v1.add_label((GraphDbTypes::Label)1); + auto v2 = dba->insert_vertex(); + v2.add_label((GraphDbTypes::Label)2); + auto v3 = dba->insert_vertex(); + v3.add_label((GraphDbTypes::Label)3); + auto edge_type = dba->edge_type("Edge"); + dba->insert_edge(v1, v2, edge_type); + dba->insert_edge(v1, v3, edge_type); + dba->advance_command(); + + AstTreeStorage storage; + SymbolTable symbol_table; + + auto test_expand = [&](EdgeAtom::Direction direction, + int expected_result_count) { + auto n = MakeScanAll(storage, symbol_table, "n"); + auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", direction, + false, "m", false); + + // make a named expression and a produce + auto output = NEXPR("m", IDENT("m")); + symbol_table[*output->expression_] = r_m.node_sym_; + symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1"); + auto produce = MakeProduce(r_m.op_, output); + + ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba); + EXPECT_EQ(result.GetResults().size(), expected_result_count); + }; + + test_expand(EdgeAtom::Direction::RIGHT, 2); + test_expand(EdgeAtom::Direction::LEFT, 2); + test_expand(EdgeAtom::Direction::BOTH, 4); +} + +TEST(QueryPlan, ExpandNodeCycle) { + Dbms dbms; + auto dba = dbms.active(); + + // make a graph (v1)->(v2) that + // has a recursive edge (v1)->(v1) + auto v1 = dba->insert_vertex(); + auto v2 = dba->insert_vertex(); + auto edge_type = dba->edge_type("Edge"); + dba->insert_edge(v1, v1, edge_type); + dba->insert_edge(v1, v2, edge_type); + dba->advance_command(); + + AstTreeStorage storage; + SymbolTable symbol_table; + + auto test_cycle = [&](bool with_cycle, int expected_result_count) { + auto n = MakeScanAll(storage, symbol_table, "n"); + auto r_n = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", + EdgeAtom::Direction::RIGHT, false, "n", with_cycle); + if (with_cycle) + symbol_table[*r_n.node_->identifier_] = + symbol_table[*n.node_->identifier_]; + + // make a named expression and a produce + auto output = NEXPR("n", IDENT("n")); + symbol_table[*output->expression_] = n.sym_; + symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1"); + auto produce = MakeProduce(r_n.op_, output); + + ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba); + EXPECT_EQ(result.GetResults().size(), expected_result_count); + }; + + test_cycle(true, 1); + test_cycle(false, 2); +} + +TEST(QueryPlan, ExpandEdgeCycle) { + Dbms dbms; + auto dba = dbms.active(); + + // make a V-graph (v3)<-[r2]-(v1)-[r1]->(v2) + auto v1 = dba->insert_vertex(); + v1.add_label((GraphDbTypes::Label)1); + auto v2 = dba->insert_vertex(); + v2.add_label((GraphDbTypes::Label)2); + auto v3 = dba->insert_vertex(); + v3.add_label((GraphDbTypes::Label)3); + auto edge_type = dba->edge_type("Edge"); + dba->insert_edge(v1, v2, edge_type); + dba->insert_edge(v1, v3, edge_type); + dba->advance_command(); + + AstTreeStorage storage; + SymbolTable symbol_table; + + auto test_cycle = [&](bool with_cycle, int expected_result_count) { + auto i = MakeScanAll(storage, symbol_table, "i"); + auto r_j = MakeExpand(storage, symbol_table, i.op_, i.sym_, "r", + EdgeAtom::Direction::BOTH, false, "j", false); + auto r_k = MakeExpand(storage, symbol_table, r_j.op_, r_j.node_sym_, "r", + EdgeAtom::Direction::BOTH, with_cycle, "k", false); + if (with_cycle) + symbol_table[*r_k.edge_->identifier_] = + symbol_table[*r_j.edge_->identifier_]; + + // make a named expression and a produce + auto output = NEXPR("r", IDENT("r")); + symbol_table[*output->expression_] = r_j.edge_sym_; + symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1"); + auto produce = MakeProduce(r_k.op_, output); + + ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba); + EXPECT_EQ(result.GetResults().size(), expected_result_count); + + }; + + test_cycle(true, 4); + test_cycle(false, 6); +} + +TEST(QueryPlan, EdgeFilter) { + Dbms dbms; + auto dba = dbms.active(); + + // make an N-star expanding from (v1) + // where only one edge will qualify + // and there are all combinations of + // (edge_type yes|no) * (property yes|absent|no) + std::vector edge_types; + for (int j = 0; j < 2; ++j) + edge_types.push_back(dba->edge_type("et" + std::to_string(j))); + std::vector vertices; + for (int i = 0; i < 7; ++i) vertices.push_back(dba->insert_vertex()); + GraphDbTypes::Property prop = dba->property("prop"); + std::vector edges; + for (int i = 0; i < 6; ++i) { + edges.push_back( + dba->insert_edge(vertices[0], vertices[i + 1], edge_types[i % 2])); + switch (i % 3) { + case 0: + edges.back().PropsSet(prop, 42); + break; + case 1: + edges.back().PropsSet(prop, 100); + break; + default: + break; + } + } + dba->advance_command(); + + AstTreeStorage storage; + SymbolTable symbol_table; + + // define an operator tree for query + // MATCH (n)-[r]->(m) RETURN m + + auto n = MakeScanAll(storage, symbol_table, "n"); + auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", + EdgeAtom::Direction::RIGHT, false, "m", false); + r_m.edge_->edge_types_.push_back(edge_types[0]); + r_m.edge_->properties_[prop] = LITERAL(42); + auto edge_filter = + std::make_shared(r_m.op_, r_m.edge_sym_, r_m.edge_); + + // make a named expression and a produce + auto output = NEXPR("m", IDENT("m")); + symbol_table[*output->expression_] = r_m.node_sym_; + symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1"); + auto produce = MakeProduce(edge_filter, output); + + ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba); + EXPECT_EQ(result.GetResults().size(), 1); +} + +TEST(QueryPlan, EdgeFilterMultipleTypes) { + Dbms dbms; + auto dba = dbms.active(); + + auto v1 = dba->insert_vertex(); + auto v2 = dba->insert_vertex(); + auto type_1 = dba->edge_type("type_1"); + auto type_2 = dba->edge_type("type_2"); + auto type_3 = dba->edge_type("type_3"); + dba->insert_edge(v1, v2, type_1); + dba->insert_edge(v1, v2, type_2); + dba->insert_edge(v1, v2, type_3); + dba->advance_command(); + + AstTreeStorage storage; + SymbolTable symbol_table; + + // make a scan all + auto n = MakeScanAll(storage, symbol_table, "n"); + auto r_m = MakeExpand(storage, symbol_table, n.op_, n.sym_, "r", + EdgeAtom::Direction::RIGHT, false, "m", false); + // add a property filter + auto edge_filter = + std::make_shared(r_m.op_, r_m.edge_sym_, r_m.edge_); + r_m.edge_->edge_types_.push_back(type_1); + r_m.edge_->edge_types_.push_back(type_2); + + // make a named expression and a produce + auto output = NEXPR("m", IDENT("m")); + auto produce = MakeProduce(edge_filter, output); + + // fill up the symbol table + symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1"); + symbol_table[*output->expression_] = r_m.node_sym_; + + ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba); + EXPECT_EQ(result.GetResults().size(), 2); +} + +TEST(QueryPlan, Filter) { + Dbms dbms; + auto dba = dbms.active(); + + // add a 6 nodes with property 'prop', 2 have true as value + GraphDbTypes::Property property = dba->property("Property"); + for (int i = 0; i < 6; ++i) + dba->insert_vertex().PropsSet(property, i % 3 == 0); + dba->insert_vertex(); // prop not set, gives NULL + dba->advance_command(); + + AstTreeStorage storage; + SymbolTable symbol_table; + + auto n = MakeScanAll(storage, symbol_table, "n"); + auto e = + storage.Create(storage.Create("n"), property); + symbol_table[*e->expression_] = n.sym_; + auto f = std::make_shared(n.op_, e); + + auto output = + storage.Create("x", storage.Create("n")); + symbol_table[*output->expression_] = n.sym_; + symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1"); + auto produce = MakeProduce(f, output); + + EXPECT_EQ(CollectProduce(produce, symbol_table, *dba).GetResults().size(), 2); +} + +TEST(QueryPlan, ExpandUniquenessFilter) { + Dbms dbms; + auto dba = dbms.active(); + + // make a graph that has (v1)->(v2) and a recursive edge (v1)->(v1) + auto v1 = dba->insert_vertex(); + auto v2 = dba->insert_vertex(); + auto edge_type = dba->edge_type("edge_type"); + dba->insert_edge(v1, v2, edge_type); + dba->insert_edge(v1, v1, edge_type); + dba->advance_command(); + + auto check_expand_results = [&](bool vertex_uniqueness, + bool edge_uniqueness) { + AstTreeStorage storage; + SymbolTable symbol_table; + + auto n1 = MakeScanAll(storage, symbol_table, "n1"); + auto r1_n2 = MakeExpand(storage, symbol_table, n1.op_, n1.sym_, "r1", + EdgeAtom::Direction::RIGHT, false, "n2", false); + std::shared_ptr last_op = r1_n2.op_; + if (vertex_uniqueness) + last_op = std::make_shared>( + last_op, r1_n2.node_sym_, std::vector{n1.sym_}); + auto r2_n3 = + MakeExpand(storage, symbol_table, last_op, r1_n2.node_sym_, "r2", + EdgeAtom::Direction::RIGHT, false, "n3", false); + last_op = r2_n3.op_; + if (edge_uniqueness) + last_op = std::make_shared>( + last_op, r2_n3.edge_sym_, std::vector{r1_n2.edge_sym_}); + if (vertex_uniqueness) + last_op = std::make_shared>( + last_op, r2_n3.node_sym_, + std::vector{n1.sym_, r1_n2.node_sym_}); + + return PullAll(last_op, *dba, symbol_table); + }; + + EXPECT_EQ(2, check_expand_results(false, false)); + EXPECT_EQ(0, check_expand_results(true, false)); + EXPECT_EQ(1, check_expand_results(false, true)); +}