From 6c7372b3c507ee4da9e09bd116f19d3f9b0bf389 Mon Sep 17 00:00:00 2001
From: florijan <florijan@memgraph.io>
Date: Wed, 15 Mar 2017 15:49:19 +0100
Subject: [PATCH] Query - AST tests in progress

Summary:
Query compiler AST tests in progress.
Added the PropertyLookup expression, the NodeFilter LogicalOperator and logical operator tests.

Reviewers: mislav.bradac, buda, teon.banek

Reviewed By: buda, teon.banek

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D126
---
 src/communication/result_stream_faker.hpp  |  12 +-
 src/query/entry.hpp                        |  10 +-
 src/query/frontend/ast/ast.hpp             |  49 +++--
 src/query/frontend/ast/ast_visitor.hpp     |   4 +
 src/query/frontend/interpret/interpret.hpp |  48 ++++-
 src/query/frontend/logical/operator.hpp    |  86 ++++++--
 tests/unit/interpreter.cpp                 | 228 +++++++++++++++++++++
 7 files changed, 388 insertions(+), 49 deletions(-)
 create mode 100644 tests/unit/interpreter.cpp

diff --git a/src/communication/result_stream_faker.hpp b/src/communication/result_stream_faker.hpp
index 8de4cf7b9..2bb709f80 100644
--- a/src/communication/result_stream_faker.hpp
+++ b/src/communication/result_stream_faker.hpp
@@ -13,20 +13,20 @@
 class ResultStreamFaker {
  public:
 
-  void Header(std::vector<std::string> &&fields) {
+  void Header(const std::vector<std::string> &fields) {
     debug_assert(current_state_ == State::Start, "Headers can only be written in the beginning");
-    header_ = std::forward(fields);
+    header_ = fields;
     current_state_ = State::WritingResults;
   }
 
-  void Result(std::vector<TypedValue> &&values) {
+  void Result(const std::vector<TypedValue> &values) {
     debug_assert(current_state_ == State::WritingResults, "Can't accept results before header nor after summary");
-    results_.push_back(std::forward(values));
+    results_.push_back(values);
   }
 
-  void Summary(std::map<std::string, TypedValue> &&summary) {
+  void Summary(const std::map<std::string, TypedValue> &summary) {
     debug_assert(current_state_ != State::Done, "Can only send a summary once");
-    summary_ = std::forward(summary);
+    summary_ = summary;
     current_state_ = State::Done;
   }
 
diff --git a/src/query/entry.hpp b/src/query/entry.hpp
index 43bc9899b..8a2927793 100644
--- a/src/query/entry.hpp
+++ b/src/query/entry.hpp
@@ -6,7 +6,6 @@
 #include "query/frontend/interpret/interpret.hpp"
 #include "query/frontend/logical/planner.hpp"
 #include "query/frontend/opencypher/parser.hpp"
-#include "query/frontend/semantic/symbol_table.hpp"
 #include "query/frontend/semantic/symbol_generator.hpp"
 
 namespace query {
@@ -50,19 +49,22 @@ class Engine {
     // AST -> high level tree
     HighLevelAstConversion low2high_tree;
     auto high_level_tree = low2high_tree.Apply(ctx, low_level_tree);
+    Execute(*high_level_tree, db_accessor, stream);
+  }
 
+  auto Execute(Query &query, GraphDbAccessor &db_accessor, Stream stream) {
     // symbol table fill
     SymbolTable symbol_table;
     SymbolGenerator symbol_generator(symbol_table);
-    high_level_tree->Accept(symbol_generator);
+    query.Accept(symbol_generator);
 
     // high level tree -> logical plan
-    auto logical_plan = MakeLogicalPlan(*high_level_tree);
+    auto logical_plan = MakeLogicalPlan(query);
 
     // generate frame based on symbol table max_position
     Frame frame(symbol_table.max_position());
 
-    auto *produce = dynamic_cast<Produce*>(logical_plan.get());
+    auto *produce = dynamic_cast<Produce *>(logical_plan.get());
 
     if (produce) {
       // top level node in the operator tree is a produce (return)
diff --git a/src/query/frontend/ast/ast.hpp b/src/query/frontend/ast/ast.hpp
index 64c7ddb76..5d31de980 100644
--- a/src/query/frontend/ast/ast.hpp
+++ b/src/query/frontend/ast/ast.hpp
@@ -9,22 +9,22 @@
 namespace query {
 
 class Tree {
-public:
+ public:
   Tree(int uid) : uid_(uid) {}
   int uid() const { return uid_; }
   virtual void Accept(TreeVisitorBase &visitor) = 0;
 
-private:
+ private:
   const int uid_;
 };
 
 class Expression : public Tree {
-public:
+ public:
   Expression(int uid) : Tree(uid) {}
 };
 
 class Identifier : public Expression {
-public:
+ public:
   Identifier(int uid, const std::string &name) : Expression(uid), name_(name) {}
 
   void Accept(TreeVisitorBase &visitor) override {
@@ -35,8 +35,31 @@ public:
   std::string name_;
 };
 
+class PropertyLookup : public Expression {
+ public:
+  PropertyLookup(int uid, std::shared_ptr<Expression> expression,
+                 GraphDb::Property property)
+      : Expression(uid), expression_(expression), property_(property) {}
+
+  void Accept(TreeVisitorBase &visitor) override {
+    visitor.Visit(*this);
+    expression_->Accept(visitor);
+    visitor.PostVisit(*this);
+  }
+
+  std::shared_ptr<Expression>
+      expression_;  // vertex or edge, what if map literal???
+  GraphDb::Property property_;
+  // TODO potential problem: property lookups are allowed on both map literals
+  // and records, but map literals have strings as keys and records have
+  // GraphDb::Property
+  //
+  // possible solution: store both string and GraphDb::Property here and choose
+  // between the two depending on Expression result
+};
+
 class NamedExpression : public Tree {
-public:
+ public:
   NamedExpression(int uid) : Tree(uid) {}
   void Accept(TreeVisitorBase &visitor) override {
     visitor.Visit(*this);
@@ -49,12 +72,12 @@ public:
 };
 
 class PatternAtom : public Tree {
-public:
+ public:
   PatternAtom(int uid) : Tree(uid) {}
 };
 
 class NodeAtom : public PatternAtom {
-public:
+ public:
   NodeAtom(int uid) : PatternAtom(uid) {}
   void Accept(TreeVisitorBase &visitor) override {
     visitor.Visit(*this);
@@ -67,7 +90,7 @@ public:
 };
 
 class EdgeAtom : public PatternAtom {
-public:
+ public:
   enum class Direction { LEFT, RIGHT, BOTH };
 
   EdgeAtom(int uid) : PatternAtom(uid) {}
@@ -82,12 +105,12 @@ public:
 };
 
 class Clause : public Tree {
-public:
+ public:
   Clause(int uid) : Tree(uid) {}
 };
 
 class Pattern : public Tree {
-public:
+ public:
   Pattern(int uid) : Tree(uid) {}
   void Accept(TreeVisitorBase &visitor) override {
     visitor.Visit(*this);
@@ -101,7 +124,7 @@ public:
 };
 
 class Query : public Tree {
-public:
+ public:
   Query(int uid) : Tree(uid) {}
   void Accept(TreeVisitorBase &visitor) override {
     visitor.Visit(*this);
@@ -114,7 +137,7 @@ public:
 };
 
 class Match : public Clause {
-public:
+ public:
   Match(int uid) : Clause(uid) {}
   std::vector<std::shared_ptr<Pattern>> patterns_;
   void Accept(TreeVisitorBase &visitor) override {
@@ -127,7 +150,7 @@ public:
 };
 
 class Return : public Clause {
-public:
+ public:
   Return(int uid) : Clause(uid) {}
   void Accept(TreeVisitorBase &visitor) override {
     visitor.Visit(*this);
diff --git a/src/query/frontend/ast/ast_visitor.hpp b/src/query/frontend/ast/ast_visitor.hpp
index 0bb6f9948..721441031 100644
--- a/src/query/frontend/ast/ast_visitor.hpp
+++ b/src/query/frontend/ast/ast_visitor.hpp
@@ -6,6 +6,7 @@ namespace query {
 class Query;
 class NamedExpression;
 class Identifier;
+class PropertyLookup;
 class Match;
 class Return;
 class Pattern;
@@ -23,6 +24,9 @@ public:
   virtual void PostVisit(NamedExpression&) {}
   virtual void Visit(Identifier&) {}
   virtual void PostVisit(Identifier&) {}
+  virtual void PreVisit(PropertyLookup&) {}
+  virtual void Visit(PropertyLookup&) {}
+  virtual void PostVisit(PropertyLookup&) {}
   // Clauses
   virtual void Visit(Match&) {}
   virtual void PostVisit(Match&) {}
diff --git a/src/query/frontend/interpret/interpret.hpp b/src/query/frontend/interpret/interpret.hpp
index ad9a3dbe5..8c2e21026 100644
--- a/src/query/frontend/interpret/interpret.hpp
+++ b/src/query/frontend/interpret/interpret.hpp
@@ -1,11 +1,12 @@
 #pragma once
 
 #include <vector>
+#include <utils/exceptions/not_yet_implemented.hpp>
 
-#include "utils/assert.hpp"
 #include "query/backend/cpp/typed_value.hpp"
 #include "query/frontend/ast/ast.hpp"
 #include "query/frontend/semantic/symbol_table.hpp"
+#include "utils/assert.hpp"
 
 namespace query {
 
@@ -26,21 +27,56 @@ class ExpressionEvaluator : public TreeVisitorBase {
   ExpressionEvaluator(Frame &frame, SymbolTable &symbol_table)
       : frame_(frame), symbol_table_(symbol_table) {}
 
+  /**
+   * Removes and returns the last value from the result stack.
+   * Consumers of this function are PostVisit functions for
+   * expressions that consume subexpressions, as well as top
+   * level expression consumers.
+   */
+  auto PopBack() {
+    debug_assert(result_stack_.size() > 0, "Result stack empty");
+    auto last = result_stack_.back();
+    result_stack_.pop_back();
+    return last;
+  }
+
   void PostVisit(NamedExpression &named_expression) override {
-    auto &symbol = symbol_table_[named_expression];
-    debug_assert(!result_stack_.empty(),
-                 "The result of evaluating a named expression is missing.");
-    frame_[symbol.position_] = result_stack_.back();
+    auto symbol = symbol_table_[named_expression];
+    frame_[symbol.position_] = PopBack();
   }
 
   void Visit(Identifier &ident) override {
     result_stack_.push_back(frame_[symbol_table_[ident].position_]);
   }
 
+  void PostVisit(PropertyLookup &property_lookup) override {
+    auto expression_result = PopBack();
+    switch (expression_result.type()) {
+      case TypedValue::Type::Vertex:
+        result_stack_.emplace_back(
+            expression_result.Value<VertexAccessor>().PropsAt(
+                property_lookup.property_));
+        break;
+      case TypedValue::Type::Edge:
+        result_stack_.emplace_back(
+            expression_result.Value<EdgeAccessor>().PropsAt(
+                property_lookup.property_));
+        break;
+
+      case TypedValue::Type::Map:
+        // TODO implement me
+        throw NotYetImplemented();
+        break;
+
+      default:
+        throw TypedValueException(
+            "Expected Node, Edge or Map for property lookup");
+    }
+  }
+
  private:
   Frame &frame_;
   SymbolTable &symbol_table_;
   std::list<TypedValue> result_stack_;
 };
-
 }
diff --git a/src/query/frontend/logical/operator.hpp b/src/query/frontend/logical/operator.hpp
index 204e71dc4..8c07916cd 100644
--- a/src/query/frontend/logical/operator.hpp
+++ b/src/query/frontend/logical/operator.hpp
@@ -40,30 +40,16 @@ class ScanAll : public LogicalOperator {
           vertices_it_(vertices_.begin()) {}
 
     bool Pull(Frame& frame, SymbolTable& symbol_table) override {
-      while (vertices_it_ != vertices_.end()) {
-        auto vertex = *vertices_it_++;
-        if (Evaluate(frame, symbol_table, vertex)) {
-          return true;
-        }
-      }
-      return false;
+      if (vertices_it_ == vertices_.end()) return false;
+      frame[symbol_table[*self_.node_atom->identifier_].position_] =
+          *vertices_it_++;
+      return true;
     }
 
    private:
     ScanAll& self_;
     decltype(std::declval<GraphDbAccessor>().vertices()) vertices_;
     decltype(vertices_.begin()) vertices_it_;
-
-    bool Evaluate(Frame& frame, SymbolTable& symbol_table,
-                  VertexAccessor& vertex) {
-      auto node_atom = self_.node_atom;
-      for (auto label : node_atom->labels_) {
-        // TODO: Move this to filter operator
-        if (!vertex.has_label(label)) return false;
-      }
-      frame[symbol_table[*node_atom->identifier_].position_] = vertex;
-      return true;
-    }
   };
 
  public:
@@ -72,10 +58,70 @@ class ScanAll : public LogicalOperator {
   }
 
  private:
-  friend class ScanAll::ScanAllCursor;
   std::shared_ptr<NodeAtom> node_atom;
 };
 
+class NodeFilter : public LogicalOperator {
+ public:
+  NodeFilter(
+      std::shared_ptr<LogicalOperator> input, Symbol input_symbol,
+      std::vector<GraphDb::Label> labels,
+      std::map<GraphDb::Property, std::shared_ptr<Expression>> properties)
+      : input_(input),
+        input_symbol_(input_symbol),
+        labels_(labels),
+        properties_(properties) {}
+
+ private:
+  class NodeFilterCursor : public Cursor {
+   public:
+    NodeFilterCursor(NodeFilter& self, GraphDbAccessor& db)
+        : self_(self), input_cursor_(self_.input_->MakeCursor(db)) {}
+
+    bool Pull(Frame& frame, SymbolTable& symbol_table) override {
+      while (input_cursor_->Pull(frame, symbol_table)) {
+        const VertexAccessor& vertex =
+            frame[self_.input_symbol_.position_].Value<VertexAccessor>();
+        if (VertexPasses(vertex, frame, symbol_table)) return true;
+      }
+      return false;
+    }
+
+   private:
+    NodeFilter& self_;
+    std::unique_ptr<Cursor> input_cursor_;
+
+    bool VertexPasses(const VertexAccessor& vertex, Frame& frame,
+                      SymbolTable& symbol_table) {
+      for (auto label : self_.labels_)
+        if (!vertex.has_label(label)) return false;
+
+      ExpressionEvaluator expression_evaluator(frame, symbol_table);
+      for (auto prop_pair : self_.properties_) {
+        prop_pair.second->Accept(expression_evaluator);
+        TypedValue comparison_result =
+            vertex.PropsAt(prop_pair.first) == expression_evaluator.PopBack();
+        if (comparison_result.type() == TypedValue::Type::Null ||
+            !comparison_result.Value<bool>())
+          return false;
+      }
+
+      return true;
+    }
+  };
+
+ public:
+  std::unique_ptr<Cursor> MakeCursor(GraphDbAccessor& db) override {
+    return std::make_unique<NodeFilterCursor>(*this, db);
+  }
+
+ private:
+  std::shared_ptr<LogicalOperator> input_;
+  const Symbol input_symbol_;
+  std::vector<GraphDb::Label> labels_;
+  std::map<GraphDb::Property, std::shared_ptr<Expression>> properties_;
+};
+
 class Produce : public LogicalOperator {
  public:
   Produce(std::shared_ptr<LogicalOperator> input,
@@ -96,9 +142,9 @@ class Produce : public LogicalOperator {
     ProduceCursor(Produce& self, GraphDbAccessor& db)
         : self_(self), self_cursor_(self_.input_->MakeCursor(db)) {}
     bool Pull(Frame& frame, SymbolTable& symbol_table) override {
+      ExpressionEvaluator evaluator(frame, symbol_table);
       if (self_cursor_->Pull(frame, symbol_table)) {
         for (auto named_expr : self_.named_expressions_) {
-          ExpressionEvaluator evaluator(frame, symbol_table);
           named_expr->Accept(evaluator);
         }
         return true;
diff --git a/tests/unit/interpreter.cpp b/tests/unit/interpreter.cpp
new file mode 100644
index 000000000..0406907c1
--- /dev/null
+++ b/tests/unit/interpreter.cpp
@@ -0,0 +1,228 @@
+//
+// Copyright 2017 Memgraph
+// Created by Florijan Stamenkovic on 14.03.17.
+//
+
+#include <memory>
+#include <vector>
+
+#include "gtest/gtest.h"
+
+#include "communication/result_stream_faker.hpp"
+#include "dbms/dbms.hpp"
+#include "query/entry.hpp"
+
+using namespace query;
+
+/**
+ * Helper function that collects all the results from the given
+ * Produce into a ResultStreamFaker and returns that object.
+ *
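+ * @param produce Produce operator whose cursor is pulled until exhausted.
+ * @param symbol_table Sizes the frame and maps named expressions to symbols.
+ * @param db_accessor Database accessor used to create the Produce cursor.
+ * @return ResultStreamFaker holding the header, all result rows and a summary.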
+ */
+auto CollectProduce(std::shared_ptr<Produce> produce, SymbolTable &symbol_table,
+                    GraphDbAccessor &db_accessor) {
+  ResultStreamFaker stream;
+  Frame frame(symbol_table.max_position());
+
+  // top level node in the operator tree is a produce (return)
+  // so stream out results
+
+  // generate header
+  std::vector<std::string> header;
+  for (auto named_expression : produce->named_expressions())
+    header.push_back(named_expression->name_);
+  stream.Header(header);
+
+  // collect the symbols from the return clause
+  std::vector<Symbol> symbols;
+  for (auto named_expression : produce->named_expressions())
+    symbols.emplace_back(symbol_table[*named_expression]);
+
+  // stream out results
+  auto cursor = produce->MakeCursor(db_accessor);
+  while (cursor->Pull(frame, symbol_table)) {
+    std::vector<TypedValue> values;
+    for (auto &symbol : symbols) values.emplace_back(frame[symbol.position_]);
+    stream.Result(values);
+  }
+
+  stream.Summary({{std::string("type"), TypedValue("r")}});
+
+  return stream;
+}
+
+/*
+ * Following are helper functions that create high level AST
+ * and logical operator objects.
+ */
+
+auto MakeNamedExpression(Context &ctx, const std::string name,
+                         std::shared_ptr<Expression> expression) {
+  auto named_expression = std::make_shared<NamedExpression>(ctx.next_uid());
+  named_expression->name_ = name;
+  named_expression->expression_ = expression;
+  return named_expression;
+}
+
+auto MakeIdentifier(Context &ctx, const std::string name) {
+  return std::make_shared<Identifier>(ctx.next_uid(), name);
+}
+
+auto MakeNode(Context &ctx, std::shared_ptr<Identifier> identifier) {
+  auto node = std::make_shared<NodeAtom>(ctx.next_uid());
+  node->identifier_ = identifier;
+  return node;
+}
+
+auto MakeScanAll(std::shared_ptr<NodeAtom> node_atom) {
+  return std::make_shared<ScanAll>(node_atom);
+}
+
+template <typename... TNamedExpressions>
+auto MakeProduce(std::shared_ptr<LogicalOperator> input,
+                 TNamedExpressions... named_expressions) {
+  return std::make_shared<Produce>(
+      input,
+      std::vector<std::shared_ptr<NamedExpression>>{named_expressions...});
+}
+
+/*
+ * Actual tests start here.
+ */
+
+TEST(Interpreter, MatchReturn) {
+  Dbms dbms;
+  auto dba = dbms.active();
+
+  // add a few nodes to the database
+  dba->insert_vertex();
+  dba->insert_vertex();
+
+  Config config;
+  Context ctx(config, *dba);
+
+  // make a scan all
+  auto node = MakeNode(ctx, MakeIdentifier(ctx, "n"));
+  auto scan_all = MakeScanAll(node);
+
+  // make a named expression and a produce
+  auto output = MakeNamedExpression(ctx, "n", MakeIdentifier(ctx, "n"));
+  auto produce = MakeProduce(scan_all, output);
+
+  // fill up the symbol table
+  SymbolTable symbol_table;
+  auto n_symbol = symbol_table.CreateSymbol("n");
+  symbol_table[*node->identifier_] = n_symbol;
+  symbol_table[*output->expression_] = n_symbol;
+  symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1");
+
+  ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba);
+  EXPECT_EQ(result.GetResults().size(), 2);
+}
+
+TEST(Interpreter, NodeFilterLabelsAndProperties) {
+  Dbms dbms;
+  auto dba = dbms.active();
+
+  // add a few nodes to the database
+  GraphDb::Label label = dba->label("Label");
+  GraphDb::Property property = dba->property("Property");
+  auto v1 = dba->insert_vertex();
+  auto v2 = dba->insert_vertex();
+  auto v3 = dba->insert_vertex();
+  auto v4 = dba->insert_vertex();
+  auto v5 = dba->insert_vertex();
+  dba->insert_vertex();
+  // test all combinations of (label | no_label) * (no_prop | wrong_prop | right_prop)
+  // only v1 has both the label and the right property value (42)
+  v1.add_label(label);
+  v2.add_label(label);
+  v3.add_label(label);
+  v1.PropsSet(property, 42);
+  v2.PropsSet(property, 1);
+  v4.PropsSet(property, 42);
+  v5.PropsSet(property, 1);
+
+  Config config;
+  Context ctx(config, *dba);
+
+  // make a scan all
+  auto node = MakeNode(ctx, MakeIdentifier(ctx, "n"));
+  auto scan_all = MakeScanAll(node);
+
+  // node filtering
+  SymbolTable symbol_table;
+  auto n_symbol = symbol_table.CreateSymbol("n");
+  // TODO filter on the property too once int-literal expressions are available (labels only for now)
+  auto node_filter = std::make_shared<NodeFilter>(
+      scan_all, n_symbol, std::vector<GraphDb::Label>{label},
+      std::map<GraphDb::Property, std::shared_ptr<Expression>>());
+
+  // make a named expression and a produce
+  auto output = MakeNamedExpression(ctx, "n", MakeIdentifier(ctx, "n"));
+  auto produce = MakeProduce(node_filter, output);
+
+  // fill up the symbol table
+  symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1");
+  symbol_table[*node->identifier_] = n_symbol;
+  symbol_table[*output->expression_] = n_symbol;
+
+  ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba);
+  EXPECT_EQ(result.GetResults().size(), 1);
+}
+
+TEST(Interpreter, NodeFilterMultipleLabels) {
+  Dbms dbms;
+  auto dba = dbms.active();
+
+  // add a few nodes to the database
+  GraphDb::Label label1 = dba->label("label1");
+  GraphDb::Label label2 = dba->label("label2");
+  GraphDb::Label label3 = dba->label("label3");
+  // the test will look for nodes that have label1 and label2
+  dba->insert_vertex(); // NOT accepted
+  dba->insert_vertex().add_label(label1); // NOT accepted
+  dba->insert_vertex().add_label(label2); // NOT accepted
+  dba->insert_vertex().add_label(label3); // NOT accepted
+  auto v1 = dba->insert_vertex(); // YES accepted
+  v1.add_label(label1);
+  v1.add_label(label2);
+  auto v2 = dba->insert_vertex(); // NOT accepted
+  v2.add_label(label1);
+  v2.add_label(label3);
+  auto v3 = dba->insert_vertex(); // YES accepted
+  v3.add_label(label1);
+  v3.add_label(label2);
+  v3.add_label(label3);
+
+  Config config;
+  Context ctx(config, *dba);
+
+  // make a scan all
+  auto node = MakeNode(ctx, MakeIdentifier(ctx, "n"));
+  auto scan_all = MakeScanAll(node);
+
+  // node filtering
+  SymbolTable symbol_table;
+  auto n_symbol = symbol_table.CreateSymbol("n");
+  // TODO implement the test once int-literal expressions are available
+  auto node_filter = std::make_shared<NodeFilter>(
+      scan_all, n_symbol, std::vector<GraphDb::Label>{label1, label2},
+      std::map<GraphDb::Property, std::shared_ptr<Expression>>());
+
+  // make a named expression and a produce
+  auto output = MakeNamedExpression(ctx, "n", MakeIdentifier(ctx, "n"));
+  auto produce = MakeProduce(node_filter, output);
+
+  // fill up the symbol table
+  symbol_table[*output] = symbol_table.CreateSymbol("named_expression_1");
+  symbol_table[*node->identifier_] = n_symbol;
+  symbol_table[*output->expression_] = n_symbol;
+
+  ResultStreamFaker result = CollectProduce(produce, symbol_table, *dba);
+  EXPECT_EQ(result.GetResults().size(), 2);
+}