From c1e4676316e77634b6429995caa166aa1618bcc5 Mon Sep 17 00:00:00 2001
From: Teon Banek <teon.banek@memgraph.io>
Date: Fri, 9 Feb 2018 16:16:29 +0100
Subject: [PATCH] Add optional total_weight variable to wShortest grammar

Reviewers: florijan, msantl, buda

Reviewed By: msantl

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D1187
---
 src/query/frontend/ast/ast.hpp                | 10 +++++++++-
 .../frontend/ast/cypher_main_visitor.cpp      | 19 +++++++++++++++++--
 .../frontend/opencypher/grammar/Cypher.g4     |  6 +++---
 .../frontend/semantic/symbol_generator.cpp    |  8 ++++++++
 tests/unit/cypher_main_visitor.cpp            |  9 +++++++--
 tests/unit/query_semantic.cpp                 |  6 ++++--
 6 files changed, 48 insertions(+), 10 deletions(-)

diff --git a/src/query/frontend/ast/ast.hpp b/src/query/frontend/ast/ast.hpp
index e20312c95..7e567bece 100644
--- a/src/query/frontend/ast/ast.hpp
+++ b/src/query/frontend/ast/ast.hpp
@@ -1768,7 +1768,10 @@ class EdgeAtom : public PatternAtom {
         cont = lower_bound_->Accept(visitor);
       }
       if (cont && upper_bound_) {
-        upper_bound_->Accept(visitor);
+        cont = upper_bound_->Accept(visitor);
+      }
+      if (cont && total_weight_) {
+        total_weight_->Accept(visitor);
       }
     }
     return visitor.PostVisit(*this);
@@ -1791,6 +1794,7 @@ class EdgeAtom : public PatternAtom {
     };
     edge_atom->filter_lambda_ = clone_lambda(filter_lambda_);
     edge_atom->weight_lambda_ = clone_lambda(weight_lambda_);
+    edge_atom->total_weight_ = CloneOpt(total_weight_, storage);
     return edge_atom;
   }
 
@@ -1825,6 +1829,8 @@ class EdgeAtom : public PatternAtom {
   /// It must have valid expressions and identifiers. In all other expand types,
   /// it is empty.
   Lambda weight_lambda_;
+  /// Variable where the total weight for weighted shortest path will be stored.
+  Identifier *total_weight_ = nullptr;
 
  protected:
   using PatternAtom::PatternAtom;
@@ -1866,6 +1872,7 @@ class EdgeAtom : public PatternAtom {
     };
     save_lambda(filter_lambda_);
     save_lambda(weight_lambda_);
+    SavePointer(ar, total_weight_);
   }
 
   template <class TArchive>
@@ -1894,6 +1901,7 @@ class EdgeAtom : public PatternAtom {
     };
     load_lambda(filter_lambda_);
     load_lambda(weight_lambda_);
+    LoadPointer(ar, total_weight_);
   }
 
   template <class TArchive>
diff --git a/src/query/frontend/ast/cypher_main_visitor.cpp b/src/query/frontend/ast/cypher_main_visitor.cpp
index b4e30be37..aba2efdc3 100644
--- a/src/query/frontend/ast/cypher_main_visitor.cpp
+++ b/src/query/frontend/ast/cypher_main_visitor.cpp
@@ -480,8 +480,8 @@ antlrcpp::Any CypherMainVisitor::visitRelationshipPattern(
     return edge;
   }
 
-  if (relationshipDetail->variable()) {
-    std::string variable = relationshipDetail->variable()->accept(this);
+  if (relationshipDetail->name) {
+    std::string variable = relationshipDetail->name->accept(this);
     edge->identifier_ = storage_.Create<Identifier>(variable);
     users_identifiers.insert(variable);
   } else {
@@ -497,6 +497,10 @@ antlrcpp::Any CypherMainVisitor::visitRelationshipPattern(
 
   auto relationshipLambdas = relationshipDetail->relationshipLambda();
   if (variableExpansion) {
+    if (relationshipDetail->total_weight &&
+        edge->type_ != EdgeAtom::Type::WEIGHTED_SHORTEST_PATH)
+      throw SemanticException(
+          "Variable for total weight is allowed only in wShortest");
     auto visit_lambda = [this](auto *lambda) {
       EdgeAtom::Lambda edge_lambda;
       std::string traversed_edge_variable =
@@ -510,6 +514,15 @@ antlrcpp::Any CypherMainVisitor::visitRelationshipPattern(
       edge_lambda.expression = lambda->expression()->accept(this);
       return edge_lambda;
     };
+    auto visit_total_weight = [&]() {
+      if (relationshipDetail->total_weight) {
+        std::string total_weight_name =
+            relationshipDetail->total_weight->accept(this);
+        edge->total_weight_ = storage_.Create<Identifier>(total_weight_name);
+      } else {
+        anonymous_identifiers.push_back(&edge->total_weight_);
+      }
+    };
     switch (relationshipLambdas.size()) {
       case 0:
         if (edge->type_ == EdgeAtom::Type::WEIGHTED_SHORTEST_PATH)
@@ -524,6 +537,7 @@ antlrcpp::Any CypherMainVisitor::visitRelationshipPattern(
           // For wShortest, the first (and required) lambda is used for weight
           // calculation.
           edge->weight_lambda_ = visit_lambda(relationshipLambdas[0]);
+          visit_total_weight();
           // Add mandatory inner variables for filter lambda.
           anonymous_identifiers.push_back(&edge->filter_lambda_.inner_edge);
           anonymous_identifiers.push_back(&edge->filter_lambda_.inner_node);
@@ -536,6 +550,7 @@ antlrcpp::Any CypherMainVisitor::visitRelationshipPattern(
         if (edge->type_ != EdgeAtom::Type::WEIGHTED_SHORTEST_PATH)
           throw SemanticException("Only one relationship lambda allowed");
         edge->weight_lambda_ = visit_lambda(relationshipLambdas[0]);
+        visit_total_weight();
         edge->filter_lambda_ = visit_lambda(relationshipLambdas[1]);
         break;
       default:
diff --git a/src/query/frontend/opencypher/grammar/Cypher.g4 b/src/query/frontend/opencypher/grammar/Cypher.g4
index 8737f4ecf..7d0a35137 100644
--- a/src/query/frontend/opencypher/grammar/Cypher.g4
+++ b/src/query/frontend/opencypher/grammar/Cypher.g4
@@ -123,9 +123,9 @@ relationshipPattern : ( leftArrowHead SP? dash SP? ( relationshipDetail )? SP? d
                     | ( dash SP? ( relationshipDetail )? SP? dash )
                     ;
 
-relationshipDetail : '[' SP? ( variable SP? )? ( relationshipTypes SP? )? ( variableExpansion SP? )?  properties SP? ']'
-                   | '[' SP? ( variable SP? )? ( relationshipTypes SP? )? ( variableExpansion SP? )?  relationshipLambda SP? (relationshipLambda SP?)? ']'
-                   | '[' SP? ( variable SP? )? ( relationshipTypes SP? )? ( variableExpansion SP? )? ( (properties SP?) | (relationshipLambda SP?) )* ']';
+relationshipDetail : '[' SP? ( name=variable SP? )? ( relationshipTypes SP? )? ( variableExpansion SP? )?  properties SP? ']'
+                   | '[' SP? ( name=variable SP? )? ( relationshipTypes SP? )? ( variableExpansion SP? )? relationshipLambda SP? ( total_weight=variable SP? )? (relationshipLambda SP?)? ']'
+                   | '[' SP? ( name=variable SP? )? ( relationshipTypes SP? )? ( variableExpansion SP? )? (properties SP?)* ( relationshipLambda SP? total_weight=variable SP? )? (relationshipLambda SP?)? ']';
 
 relationshipLambda: '(' SP? traversed_edge=variable SP? ',' SP? traversed_node=variable SP? '|' SP? expression SP? ')';
 
diff --git a/src/query/frontend/semantic/symbol_generator.cpp b/src/query/frontend/semantic/symbol_generator.cpp
index 5413aedfe..2d5fa47fa 100644
--- a/src/query/frontend/semantic/symbol_generator.cpp
+++ b/src/query/frontend/semantic/symbol_generator.cpp
@@ -443,6 +443,14 @@ bool SymbolGenerator::PreVisit(EdgeAtom &edge_atom) {
   scope_.in_pattern_atom_identifier = true;
   edge_atom.identifier_->Accept(*this);
   scope_.in_pattern_atom_identifier = false;
+  if (edge_atom.total_weight_) {
+    if (HasSymbol(edge_atom.total_weight_->name_)) {
+      throw RedeclareVariableError(edge_atom.total_weight_->name_);
+    }
+    symbol_table_[*edge_atom.total_weight_] = GetOrCreateSymbol(
+        edge_atom.total_weight_->name_, edge_atom.total_weight_->user_declared_,
+        Symbol::Type::Number);
+  }
   return false;
 }
 
diff --git a/tests/unit/cypher_main_visitor.cpp b/tests/unit/cypher_main_visitor.cpp
index ce237e40e..244e3ed6d 100644
--- a/tests/unit/cypher_main_visitor.cpp
+++ b/tests/unit/cypher_main_visitor.cpp
@@ -1632,8 +1632,8 @@ TYPED_TEST(CypherMainVisitorTest, MatchVariableLambdaSymbols) {
 
 TYPED_TEST(CypherMainVisitorTest, MatchWShortestReturn) {
   TypeParam ast_generator(
-      "MATCH ()-[r:type1|type2 *wShortest 10 (we, wn | 42) (e, n | true)]->() "
-      "RETURN r");
+      "MATCH ()-[r:type1|type2 *wShortest 10 (we, wn | 42) total_weight "
+      "(e, n | true)]->() RETURN r");
   auto *query = ast_generator.query_;
   ASSERT_TRUE(query->single_query_);
   auto *single_query = query->single_query_;
@@ -1665,6 +1665,9 @@ TYPED_TEST(CypherMainVisitorTest, MatchWShortestReturn) {
   EXPECT_EQ(shortest->weight_lambda_.inner_node->name_, "wn");
   EXPECT_TRUE(shortest->weight_lambda_.inner_node->user_declared_);
   CheckLiteral(ast_generator.context_, shortest->weight_lambda_.expression, 42);
+  ASSERT_TRUE(shortest->total_weight_);
+  EXPECT_EQ(shortest->total_weight_->name_, "total_weight");
+  EXPECT_TRUE(shortest->total_weight_->user_declared_);
 }
 
 TYPED_TEST(CypherMainVisitorTest, MatchWShortestNoFilterReturn) {
@@ -1699,6 +1702,8 @@ TYPED_TEST(CypherMainVisitorTest, MatchWShortestNoFilterReturn) {
   EXPECT_EQ(shortest->weight_lambda_.inner_node->name_, "wn");
   EXPECT_TRUE(shortest->weight_lambda_.inner_node->user_declared_);
   CheckLiteral(ast_generator.context_, shortest->weight_lambda_.expression, 42);
+  ASSERT_TRUE(shortest->total_weight_);
+  EXPECT_FALSE(shortest->total_weight_->user_declared_);
 }
 
 
diff --git a/tests/unit/query_semantic.cpp b/tests/unit/query_semantic.cpp
index b0581bc53..09839c216 100644
--- a/tests/unit/query_semantic.cpp
+++ b/tests/unit/query_semantic.cpp
@@ -944,6 +944,7 @@ TEST_F(TestSymbolGenerator, MatchWShortestReturn) {
     shortest->weight_lambda_.inner_edge = IDENT("r");
     shortest->weight_lambda_.inner_node = IDENT("n");
     shortest->weight_lambda_.expression = r_weight;
+    shortest->total_weight_ = IDENT("total_weight");
   }
   {
     shortest->filter_lambda_.inner_edge = IDENT("r");
@@ -954,8 +955,9 @@ TEST_F(TestSymbolGenerator, MatchWShortestReturn) {
   auto *query = QUERY(SINGLE_QUERY(MATCH(PATTERN(node_n, shortest, NODE("m"))),
                                    RETURN(ret_r, AS("r"))));
   query->Accept(symbol_generator);
-  // Symbols for pattern, `n`, `[r]`, (`r|`, `n|`)x2, `m` and `AS r`.
-  EXPECT_EQ(symbol_table.max_position(), 9);
+  // Symbols for pattern, `n`, `[r]`, `total_weight`, (`r|`, `n|`)x2, `m` and
+  // `AS r`.
+  EXPECT_EQ(symbol_table.max_position(), 10);
   EXPECT_EQ(symbol_table.at(*ret_r), symbol_table.at(*shortest->identifier_));
   EXPECT_NE(symbol_table.at(*ret_r),
             symbol_table.at(*shortest->weight_lambda_.inner_edge));