From 6c7372b3c507ee4da9e09bd116f19d3f9b0bf389 Mon Sep 17 00:00:00 2001 From: florijan <florijan@memgraph.io> Date: Wed, 15 Mar 2017 15:49:19 +0100 Subject: [PATCH] Query - AST tests in progress Summary: Query compiler AST test in progress Logical operator testing. NodeFilter LogicalOperator added Reviewers: mislav.bradac, buda, teon.banek Reviewed By: buda, teon.banek Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D126 --- src/communication/result_stream_faker.hpp | 12 +- src/query/entry.hpp | 10 +- src/query/frontend/ast/ast.hpp | 49 +++-- src/query/frontend/ast/ast_visitor.hpp | 4 + src/query/frontend/interpret/interpret.hpp | 48 ++++- src/query/frontend/logical/operator.hpp | 86 ++++++-- tests/unit/interpreter.cpp | 228 +++++++++++++++++++++ 7 files changed, 388 insertions(+), 49 deletions(-) create mode 100644 tests/unit/interpreter.cpp diff --git a/src/communication/result_stream_faker.hpp b/src/communication/result_stream_faker.hpp index 8de4cf7b9..2bb709f80 100644 --- a/src/communication/result_stream_faker.hpp +++ b/src/communication/result_stream_faker.hpp @@ -13,20 +13,20 @@ class ResultStreamFaker { public: - void Header(std::vector<std::string> &&fields) { + void Header(const std::vector<std::string> &fields) { debug_assert(current_state_ == State::Start, "Headers can only be written in the beginning"); - header_ = std::forward(fields); + header_ = fields; current_state_ = State::WritingResults; } - void Result(std::vector<TypedValue> &&values) { + void Result(const std::vector<TypedValue> &values) { debug_assert(current_state_ == State::WritingResults, "Can't accept results before header nor after summary"); - results_.push_back(std::forward(values)); + results_.push_back(values); } - void Summary(std::map<std::string, TypedValue> &&summary) { + void Summary(const std::map<std::string, TypedValue> &summary) { debug_assert(current_state_ != State::Done, "Can only send a summary once"); - summary_ = std::forward(summary); + summary_ = summary; current_state_ = State::Done; } diff --git a/src/query/entry.hpp b/src/query/entry.hpp index 43bc9899b..8a2927793 100644 --- a/src/query/entry.hpp +++ b/src/query/entry.hpp @@ -6,7 +6,6 @@ #include "query/frontend/interpret/interpret.hpp" #include "query/frontend/logical/planner.hpp" #include "query/frontend/opencypher/parser.hpp" -#include "query/frontend/semantic/symbol_table.hpp" #include "query/frontend/semantic/symbol_generator.hpp" namespace query { @@ -50,19 +49,22 @@ class Engine { // AST -> high level tree HighLevelAstConversion low2high_tree; auto high_level_tree = low2high_tree.Apply(ctx, low_level_tree); + Execute(*high_level_tree, db_accessor, stream); + } + auto Execute(Query &query, GraphDbAccessor &db_accessor, Stream stream) { // symbol table fill SymbolTable symbol_table; SymbolGenerator symbol_generator(symbol_table); - high_level_tree->Accept(symbol_generator); + query.Accept(symbol_generator); // high level tree -> logical plan - auto logical_plan = MakeLogicalPlan(*high_level_tree); + auto logical_plan = MakeLogicalPlan(query); // generate frame based on symbol table max_position Frame frame(symbol_table.max_position()); - auto *produce = dynamic_cast<Produce*>(logical_plan.get()); + auto *produce = dynamic_cast<Produce *>(logical_plan.get()); if (produce) { // top level node in the operator tree is a produce (return) diff --git a/src/query/frontend/ast/ast.hpp b/src/query/frontend/ast/ast.hpp index 64c7ddb76..5d31de980 100644 --- a/src/query/frontend/ast/ast.hpp +++ b/src/query/frontend/ast/ast.hpp @@ -9,22 +9,22 @@ namespace query { class Tree { -public: + public: Tree(int uid) : uid_(uid) {} int uid() const { return uid_; } virtual void Accept(TreeVisitorBase &visitor) = 0; -private: + private: const int uid_; }; class Expression : public Tree { -public: + public: Expression(int uid) : Tree(uid) {} }; class Identifier : public Expression { -public: + public: Identifier(int uid, const std::string &name) : Expression(uid), name_(name) {} void Accept(TreeVisitorBase &visitor) override { @@ -35,8 +35,31 @@ public: std::string name_; }; +class PropertyLookup : public Expression { + public: + PropertyLookup(int uid, std::shared_ptr<Expression> expression, + GraphDb::Property property) + : Expression(uid), expression_(expression), property_(property) {} + + void Accept(TreeVisitorBase &visitor) override { + visitor.Visit(*this); + expression_->Accept(visitor); + visitor.PostVisit(*this); + } + + std::shared_ptr<Expression> + expression_; // vertex or edge, what if map literal??? + GraphDb::Property property_; + // TODO potential problem: property lookups are allowed on both map literals + // and records, but map literals have strings as keys and records have + // GraphDb::Property + // + // possible solution: store both string and GraphDb::Property here and choose + // between the two depending on Expression result +}; + class NamedExpression : public Tree { -public: + public: NamedExpression(int uid) : Tree(uid) {} void Accept(TreeVisitorBase &visitor) override { visitor.Visit(*this); @@ -49,12 +72,12 @@ public: }; class PatternAtom : public Tree { -public: + public: PatternAtom(int uid) : Tree(uid) {} }; class NodeAtom : public PatternAtom { -public: + public: NodeAtom(int uid) : PatternAtom(uid) {} void Accept(TreeVisitorBase &visitor) override { visitor.Visit(*this); @@ -67,7 +90,7 @@ public: }; class EdgeAtom : public PatternAtom { -public: + public: enum class Direction { LEFT, RIGHT, BOTH }; EdgeAtom(int uid) : PatternAtom(uid) {} @@ -82,12 +105,12 @@ public: }; class Clause : public Tree { -public: + public: Clause(int uid) : Tree(uid) {} }; class Pattern : public Tree { -public: + public: Pattern(int uid) : Tree(uid) {} void Accept(TreeVisitorBase &visitor) override { visitor.Visit(*this); @@ -101,7 +124,7 @@ public: }; class Query : public Tree { -public: + public: Query(int uid) : Tree(uid) {} void Accept(TreeVisitorBase &visitor) override { visitor.Visit(*this); @@ -114,7 +137,7 @@ public: }; class Match : public Clause { -public: + public: Match(int uid) : Clause(uid) {} std::vector<std::shared_ptr<Pattern>> patterns_; void Accept(TreeVisitorBase &visitor) override { @@ -127,7 +150,7 @@ public: }; class Return : public Clause { -public: + public: Return(int uid) : Clause(uid) {} void Accept(TreeVisitorBase &visitor) override { visitor.Visit(*this); diff --git a/src/query/frontend/ast/ast_visitor.hpp b/src/query/frontend/ast/ast_visitor.hpp index 0bb6f9948..721441031 100644 --- a/src/query/frontend/ast/ast_visitor.hpp +++ b/src/query/frontend/ast/ast_visitor.hpp @@ -6,6 +6,7 @@ namespace query { class Query; class NamedExpression; class Identifier; +class PropertyLookup; class Match; class Return; class Pattern; @@ -23,6 +24,9 @@ public: virtual void PostVisit(NamedExpression&) {} virtual void Visit(Identifier&) {} virtual void PostVisit(Identifier&) {} + virtual void PreVisit(PropertyLookup&) {} + virtual void Visit(PropertyLookup&) {} + virtual void PostVisit(PropertyLookup&) {} // Clauses virtual void Visit(Match&) {} virtual void PostVisit(Match&) {} diff --git a/src/query/frontend/interpret/interpret.hpp b/src/query/frontend/interpret/interpret.hpp index ad9a3dbe5..8c2e21026 100644 --- a/src/query/frontend/interpret/interpret.hpp +++ b/src/query/frontend/interpret/interpret.hpp @@ -1,11 +1,12 @@ #pragma once #include <vector> +#include <utils/exceptions/not_yet_implemented.hpp> -#include "utils/assert.hpp" #include "query/backend/cpp/typed_value.hpp" #include "query/frontend/ast/ast.hpp" #include "query/frontend/semantic/symbol_table.hpp" +#include "utils/assert.hpp" namespace query { @@ -26,21 +27,56 @@ class ExpressionEvaluator : public TreeVisitorBase { ExpressionEvaluator(Frame &frame, SymbolTable &symbol_table) : frame_(frame), symbol_table_(symbol_table) {} + /** + * Removes and returns the last value from the result stack. + * Consumers of this function are PostVisit functions for + * expressions that consume subexpressions, as well as top + * level expression consumers. + */ + auto PopBack() { + debug_assert(result_stack_.size() > 0, "Result stack empty"); + auto last = result_stack_.back(); + result_stack_.pop_back(); + return last; + } + void PostVisit(NamedExpression &named_expression) override { - auto &symbol = symbol_table_[named_expression]; - debug_assert(!result_stack_.empty(), - "The result of evaluating a named expression is missing."); - frame_[symbol.position_] = result_stack_.back(); + auto symbol = symbol_table_[named_expression]; + frame_[symbol.position_] = PopBack(); } void Visit(Identifier &ident) override { result_stack_.push_back(frame_[symbol_table_[ident].position_]); } + void PostVisit(PropertyLookup &property_lookup) override { + auto expression_result = PopBack(); + switch (expression_result.type()) { + case TypedValue::Type::Vertex: + result_stack_.emplace_back( + expression_result.Value<VertexAccessor>().PropsAt( + property_lookup.property_)); + break; + case TypedValue::Type::Edge: + result_stack_.emplace_back( + expression_result.Value<EdgeAccessor>().PropsAt( + property_lookup.property_)); + break; + + case TypedValue::Type::Map: + // TODO implement me + throw NotYetImplemented(); + break; + + default: + throw TypedValueException( + "Expected Node, Edge or Map for property lookup"); + } + } + private: Frame &frame_; SymbolTable &symbol_table_; std::list<TypedValue> result_stack_; }; - } diff --git a/src/query/frontend/logical/operator.hpp b/src/query/frontend/logical/operator.hpp index 204e71dc4..8c07916cd 100644 --- a/src/query/frontend/logical/operator.hpp +++ b/src/query/frontend/logical/operator.hpp @@ -40,30 +40,16 @@ class ScanAll : public LogicalOperator { vertices_it_(vertices_.begin()) {} bool Pull(Frame& frame, SymbolTable& symbol_table) override { - while (vertices_it_ != vertices_.end()) { - auto vertex = *vertices_it_++; - if (Evaluate(frame, symbol_table, vertex)) { - return true; - } - } - return false; + if (vertices_it_ == vertices_.end()) return false; + frame[symbol_table[*self_.node_atom->identifier_].position_] = + *vertices_it_++; + return true; } private: ScanAll& self_; decltype(std::declval<GraphDbAccessor>().vertices()) vertices_; decltype(vertices_.begin()) vertices_it_; - - bool Evaluate(Frame& frame, SymbolTable& symbol_table, - VertexAccessor& vertex) { - auto node_atom = self_.node_atom; - for (auto label : node_atom->labels_) { - // TODO: Move this to filter operator - if (!vertex.has_label(label)) return false; - } - frame[symbol_table[*node_atom->identifier_].position_] = vertex; - return true; - } }; public: @@ -72,10 +58,70 @@ class ScanAll : public LogicalOperator { } private: - friend class ScanAll::ScanAllCursor; std::shared_ptr<NodeAtom> node_atom; }; +class NodeFilter : public LogicalOperator { + public: + NodeFilter( + std::shared_ptr<LogicalOperator> input, Symbol input_symbol, + std::vector<GraphDb::Label> labels, + std::map<GraphDb::Property, std::shared_ptr<Expression>> properties) + : input_(input), + input_symbol_(input_symbol), + labels_(labels), + properties_(properties) {} + + private: + class NodeFilterCursor : public Cursor { + public: + NodeFilterCursor(NodeFilter& self, GraphDbAccessor& db) + : self_(self), input_cursor_(self_.input_->MakeCursor(db)) {} + + bool Pull(Frame& frame, SymbolTable& symbol_table) override { + while (input_cursor_->Pull(frame, symbol_table)) { + const VertexAccessor& vertex = + frame[self_.input_symbol_.position_].Value<VertexAccessor>(); + if (VertexPasses(vertex, frame, symbol_table)) return true; + } + return false; + } + + private: + NodeFilter& self_; + std::unique_ptr<Cursor> input_cursor_; + + bool VertexPasses(const VertexAccessor& vertex, Frame& frame, + SymbolTable& symbol_table) { + for (auto label : self_.labels_) + if (!vertex.has_label(label)) return false; + + ExpressionEvaluator expression_evaluator(frame, symbol_table); + for (auto prop_pair : self_.properties_) { + prop_pair.second->Accept(expression_evaluator); + TypedValue comparison_result = + vertex.PropsAt(prop_pair.first) == expression_evaluator.PopBack(); + if (comparison_result.type() == TypedValue::Type::Null || + !comparison_result.Value<bool>()) + return false; + } + + return true; + } + }; + + public: + std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor& db) override { + return std::make_unique<NodeFilterCursor>(*this, db); + } + + private: + std::shared_ptr<LogicalOperator> input_; + const Symbol input_symbol_; + std::vector<GraphDb::Label> labels_; + std::map<GraphDb::Property, std::shared_ptr<Expression>> properties_; +}; + class Produce : public LogicalOperator { public: Produce(std::shared_ptr<LogicalOperator> input, @@ -96,9 +142,9 @@ class Produce : public LogicalOperator { ProduceCursor(Produce& self, GraphDbAccessor& db) : self_(self), self_cursor_(self_.input_->MakeCursor(db)) {} bool Pull(Frame& frame, SymbolTable& symbol_table) override { + ExpressionEvaluator evaluator(frame, symbol_table); if (self_cursor_->Pull(frame, symbol_table)) { for (auto named_expr : self_.named_expressions_) { - ExpressionEvaluator evaluator(frame, symbol_table); named_expr->Accept(evaluator); } return true; diff --git a/tests/unit/interpreter.cpp b/tests/unit/interpreter.cpp new file mode 100644 index 000000000..0406907c1 --- /dev/null +++ b/tests/unit/interpreter.cpp @@ -0,0 +1,228 @@ +// +// Copyright 2017 Memgraph +// Created by Florijan Stamenkovic on 14.03.17. +// + +#include <memory> +#include <vector> + +#include "gtest/gtest.h" + +#include "communication/result_stream_faker.hpp" +#include "dbms/dbms.hpp" +#include "query/entry.hpp" + +using namespace query; + +/** + * Helper function that collects all the results from the given + * Produce into a ResultStreamFaker and returns that object. + * + * @param produce + * @param symbol_table + * @param db_accessor + * @return + */ +auto CollectProduce(std::shared_ptr<Produce> produce, SymbolTable &symbol_table, + GraphDbAccessor &db_accessor) { + ResultStreamFaker stream; + Frame frame(symbol_table.max_position()); + + // top level node in the operator tree is a produce (return) + // so stream out results + + // generate header + std::vector<std::string> header; + for (auto named_expression : produce->named_expressions()) + header.push_back(named_expression->name_); + stream.Header(header); + + // collect the symbols from the return clause + std::vector<Symbol> symbols; + for (auto named_expression : produce->named_expressions()) + symbols.emplace_back(symbol_table[*named_expression]); + + // stream out results + auto cursor = produce->MakeCursor(db_accessor); + while (cursor->Pull(frame, symbol_table)) { + std::vector<TypedValue> values; + for (auto &symbol : symbols) values.emplace_back(frame[symbol.position_]); + stream.Result(values); + } + + stream.Summary({{std::string("type"), TypedValue("r")}}); + + return stream; +} + +/* + * Following are helper functions that create high level AST + * and logical operator objects. + */ + +auto MakeNamedExpression(Context &ctx, const std::string name, + std::shared_ptr<Expression> expression) { + auto named_expression = std::make_shared<NamedExpression>(ctx.next_uid()); + named_expression->name_ = name; + named_expression->expression_ = expression; + return named_expression; +} + +auto MakeIdentifier(Context &ctx, const std::string name) { + return std::make_shared<Identifier>(ctx.next_uid(), name); +} + +auto MakeNode(Context &ctx, std::shared_ptr<Identifier> identifier) { + auto node = std::make_shared<NodeAtom>(ctx.next_uid()); + node->identifier_ = identifier; + return node; +} + +auto MakeScanAll(std::shared_ptr<NodeAtom> node_atom) { + return std::make_shared<ScanAll>(node_atom); +} + +template <typename... TNamedExpressions> +auto MakeProduce(std::shared_ptr<LogicalOperator> input, + TNamedExpressions... named_expressions) { + return std::make_shared<Produce>( + input, + std::vector<std::shared_ptr<NamedExpression>>{named_expressions...}); +} + +/* + * Actual tests start here. + */ + +TEST(Interpreter, MatchReturn) { + Dbms dbms; + auto dba = dbms.active(); + + // add a few nodes to the database + dba->insert_vertex(); + dba->insert_vertex(); + + Config config; + Context ctx(config, *dba); + + // make a scan all + auto node = MakeNode(ctx, MakeIdentifier(ctx, "n")); + auto scan_all = MakeScanAll(node); + + // make a named expression and a produce + auto output = MakeNamedExpression(ctx, "n", MakeIdentifier(ctx, "n")); + auto produce = MakeProduce(scan_all, output); + + // fill up the symbol table + SymbolTable symbol_table; + auto n_symbol = symbol_table.CreateSymbol("n"); + symbol_table[*node->identifier_] = n_symbol; + symbol_table[*output->expression_] = n_symbol; + symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1"); + + ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba); + EXPECT_EQ(result.GetResults().size(), 2); +} + +TEST(Interpreter, NodeFilterLabelsAndProperties) { + Dbms dbms; + auto dba = dbms.active(); + + // add a few nodes to the database + GraphDb::Label label = dba->label("Label"); + GraphDb::Property property = dba->property("Property"); + auto v1 = dba->insert_vertex(); + auto v2 = dba->insert_vertex(); + auto v3 = dba->insert_vertex(); + auto v4 = dba->insert_vertex(); + auto v5 = dba->insert_vertex(); + dba->insert_vertex(); + // test all combination of (label | no_label) * (no_prop | wrong_prop | right_prop) + // only v1 will have the right labels + v1.add_label(label); + v2.add_label(label); + v3.add_label(label); + v1.PropsSet(property, 42); + v2.PropsSet(property, 1); + v4.PropsSet(property, 42); + v5.PropsSet(property, 1); + + Config config; + Context ctx(config, *dba); + + // make a scan all + auto node = MakeNode(ctx, MakeIdentifier(ctx, "n")); + auto scan_all = MakeScanAll(node); + + // node filtering + SymbolTable symbol_table; + auto n_symbol = symbol_table.CreateSymbol("n"); + // TODO implement the test once int-literal expressions are available + auto node_filter = std::make_shared<NodeFilter>( + scan_all, n_symbol, std::vector<GraphDb::Label>{label}, + std::map<GraphDb::Property, std::shared_ptr<Expression>>()); + + // make a named expression and a produce + auto output = MakeNamedExpression(ctx, "n", MakeIdentifier(ctx, "n")); + auto produce = MakeProduce(node_filter, output); + + // fill up the symbol table + symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1"); + symbol_table[*node->identifier_] = n_symbol; + symbol_table[*output->expression_] = n_symbol; + + ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba); + EXPECT_EQ(result.GetResults().size(), 1); +} + +TEST(Interpreter, NodeFilterMultipleLabels) { + Dbms dbms; + auto dba = dbms.active(); + + // add a few nodes to the database + GraphDb::Label label1 = dba->label("label1"); + GraphDb::Label label2 = dba->label("label2"); + GraphDb::Label label3 = dba->label("label3"); + // the test will look for nodes that have label1 and label2 + dba->insert_vertex(); // NOT accepted + dba->insert_vertex().add_label(label1); // NOT accepted + dba->insert_vertex().add_label(label2); // NOT accepted + dba->insert_vertex().add_label(label3); // NOT accepted + auto v1 = dba->insert_vertex(); // YES accepted + v1.add_label(label1); + v1.add_label(label2); + auto v2 = dba->insert_vertex(); // NOT accepted + v2.add_label(label1); + v2.add_label(label3); + auto v3 = dba->insert_vertex(); // YES accepted + v3.add_label(label1); + v3.add_label(label2); + v3.add_label(label3); + + Config config; + Context ctx(config, *dba); + + // make a scan all + auto node = MakeNode(ctx, MakeIdentifier(ctx, "n")); + auto scan_all = MakeScanAll(node); + + // node filtering + SymbolTable symbol_table; + auto n_symbol = symbol_table.CreateSymbol("n"); + // TODO implement the test once int-literal expressions are available + auto node_filter = std::make_shared<NodeFilter>( + scan_all, n_symbol, std::vector<GraphDb::Label>{label1, label2}, + std::map<GraphDb::Property, std::shared_ptr<Expression>>()); + + // make a named expression and a produce + auto output = MakeNamedExpression(ctx, "n", MakeIdentifier(ctx, "n")); + auto produce = MakeProduce(node_filter, output); + + // fill up the symbol table + symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1"); + symbol_table[*node->identifier_] = n_symbol; + symbol_table[*output->expression_] = n_symbol; + + ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba); + EXPECT_EQ(result.GetResults().size(), 2); +}