From c1e4676316e77634b6429995caa166aa1618bcc5 Mon Sep 17 00:00:00 2001 From: Teon Banek <teon.banek@memgraph.io> Date: Fri, 9 Feb 2018 16:16:29 +0100 Subject: [PATCH] Add optional total_weight variable to wShortest grammar Reviewers: florijan, msantl, buda Reviewed By: msantl Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D1187 --- src/query/frontend/ast/ast.hpp | 10 +++++++++- .../frontend/ast/cypher_main_visitor.cpp | 19 +++++++++++++++++-- .../frontend/opencypher/grammar/Cypher.g4 | 6 +++--- .../frontend/semantic/symbol_generator.cpp | 8 ++++++++ tests/unit/cypher_main_visitor.cpp | 9 +++++++-- tests/unit/query_semantic.cpp | 6 ++++-- 6 files changed, 48 insertions(+), 10 deletions(-) diff --git a/src/query/frontend/ast/ast.hpp b/src/query/frontend/ast/ast.hpp index e20312c95..7e567bece 100644 --- a/src/query/frontend/ast/ast.hpp +++ b/src/query/frontend/ast/ast.hpp @@ -1768,7 +1768,10 @@ class EdgeAtom : public PatternAtom { cont = lower_bound_->Accept(visitor); } if (cont && upper_bound_) { - upper_bound_->Accept(visitor); + cont = upper_bound_->Accept(visitor); + } + if (cont && total_weight_) { + total_weight_->Accept(visitor); } } return visitor.PostVisit(*this); @@ -1791,6 +1794,7 @@ class EdgeAtom : public PatternAtom { }; edge_atom->filter_lambda_ = clone_lambda(filter_lambda_); edge_atom->weight_lambda_ = clone_lambda(weight_lambda_); + edge_atom->total_weight_ = CloneOpt(total_weight_, storage); return edge_atom; } @@ -1825,6 +1829,8 @@ class EdgeAtom : public PatternAtom { /// It must have valid expressions and identifiers. In all other expand types, /// it is empty. Lambda weight_lambda_; + /// Variable where the total weight for weighted shortest path will be stored. + Identifier *total_weight_ = nullptr; protected: using PatternAtom::PatternAtom; @@ -1866,6 +1872,7 @@ class EdgeAtom : public PatternAtom { }; save_lambda(filter_lambda_); save_lambda(weight_lambda_); + SavePointer(ar, total_weight_); } template <class TArchive> @@ -1894,6 +1901,7 @@ class EdgeAtom : public PatternAtom { }; load_lambda(filter_lambda_); load_lambda(weight_lambda_); + LoadPointer(ar, total_weight_); } template <class TArchive> diff --git a/src/query/frontend/ast/cypher_main_visitor.cpp b/src/query/frontend/ast/cypher_main_visitor.cpp index b4e30be37..aba2efdc3 100644 --- a/src/query/frontend/ast/cypher_main_visitor.cpp +++ b/src/query/frontend/ast/cypher_main_visitor.cpp @@ -480,8 +480,8 @@ antlrcpp::Any CypherMainVisitor::visitRelationshipPattern( return edge; } - if (relationshipDetail->variable()) { - std::string variable = relationshipDetail->variable()->accept(this); + if (relationshipDetail->name) { + std::string variable = relationshipDetail->name->accept(this); edge->identifier_ = storage_.Create<Identifier>(variable); users_identifiers.insert(variable); } else { @@ -497,6 +497,10 @@ antlrcpp::Any CypherMainVisitor::visitRelationshipPattern( auto relationshipLambdas = relationshipDetail->relationshipLambda(); if (variableExpansion) { + if (relationshipDetail->total_weight && + edge->type_ != EdgeAtom::Type::WEIGHTED_SHORTEST_PATH) + throw SemanticException( + "Variable for total weight is allowed only in wShortest"); auto visit_lambda = [this](auto *lambda) { EdgeAtom::Lambda edge_lambda; std::string traversed_edge_variable = @@ -510,6 +514,15 @@ antlrcpp::Any CypherMainVisitor::visitRelationshipPattern( edge_lambda.expression = lambda->expression()->accept(this); return edge_lambda; }; + auto visit_total_weight = [&]() { + if (relationshipDetail->total_weight) { + std::string total_weight_name = + relationshipDetail->total_weight->accept(this); + edge->total_weight_ = storage_.Create<Identifier>(total_weight_name); + } else { + anonymous_identifiers.push_back(&edge->total_weight_); + } + }; switch (relationshipLambdas.size()) { case 0: if (edge->type_ == EdgeAtom::Type::WEIGHTED_SHORTEST_PATH) @@ -524,6 +537,7 @@ antlrcpp::Any CypherMainVisitor::visitRelationshipPattern( // For wShortest, the first (and required) lambda is used for weight // calculation. edge->weight_lambda_ = visit_lambda(relationshipLambdas[0]); + visit_total_weight(); // Add mandatory inner variables for filter lambda. anonymous_identifiers.push_back(&edge->filter_lambda_.inner_edge); anonymous_identifiers.push_back(&edge->filter_lambda_.inner_node); @@ -536,6 +550,7 @@ antlrcpp::Any CypherMainVisitor::visitRelationshipPattern( if (edge->type_ != EdgeAtom::Type::WEIGHTED_SHORTEST_PATH) throw SemanticException("Only one relationship lambda allowed"); edge->weight_lambda_ = visit_lambda(relationshipLambdas[0]); + visit_total_weight(); edge->filter_lambda_ = visit_lambda(relationshipLambdas[1]); break; default: diff --git a/src/query/frontend/opencypher/grammar/Cypher.g4 b/src/query/frontend/opencypher/grammar/Cypher.g4 index 8737f4ecf..7d0a35137 100644 --- a/src/query/frontend/opencypher/grammar/Cypher.g4 +++ b/src/query/frontend/opencypher/grammar/Cypher.g4 @@ -123,9 +123,9 @@ relationshipPattern : ( leftArrowHead SP? dash SP? ( relationshipDetail )? SP? d | ( dash SP? ( relationshipDetail )? SP? dash ) ; -relationshipDetail : '[' SP? ( variable SP? )? ( relationshipTypes SP? )? ( variableExpansion SP? )? properties SP? ']' - | '[' SP? ( variable SP? )? ( relationshipTypes SP? )? ( variableExpansion SP? )? relationshipLambda SP? (relationshipLambda SP?)? ']' - | '[' SP? ( variable SP? )? ( relationshipTypes SP? )? ( variableExpansion SP? )? ( (properties SP?) | (relationshipLambda SP?) )* ']'; +relationshipDetail : '[' SP? ( name=variable SP? )? ( relationshipTypes SP? )? ( variableExpansion SP? )? properties SP? ']' + | '[' SP? ( name=variable SP? )? ( relationshipTypes SP? )? ( variableExpansion SP? )? relationshipLambda SP? ( total_weight=variable SP? )? (relationshipLambda SP?)? ']' + | '[' SP? ( name=variable SP? )? ( relationshipTypes SP? )? ( variableExpansion SP? )? (properties SP?)* ( relationshipLambda SP? total_weight=variable SP? )? (relationshipLambda SP?)? ']'; relationshipLambda: '(' SP? traversed_edge=variable SP? ',' SP? traversed_node=variable SP? '|' SP? expression SP? ')'; diff --git a/src/query/frontend/semantic/symbol_generator.cpp b/src/query/frontend/semantic/symbol_generator.cpp index 5413aedfe..2d5fa47fa 100644 --- a/src/query/frontend/semantic/symbol_generator.cpp +++ b/src/query/frontend/semantic/symbol_generator.cpp @@ -443,6 +443,14 @@ bool SymbolGenerator::PreVisit(EdgeAtom &edge_atom) { scope_.in_pattern_atom_identifier = true; edge_atom.identifier_->Accept(*this); scope_.in_pattern_atom_identifier = false; + if (edge_atom.total_weight_) { + if (HasSymbol(edge_atom.total_weight_->name_)) { + throw RedeclareVariableError(edge_atom.total_weight_->name_); + } + symbol_table_[*edge_atom.total_weight_] = GetOrCreateSymbol( + edge_atom.total_weight_->name_, edge_atom.total_weight_->user_declared_, + Symbol::Type::Number); + } return false; } diff --git a/tests/unit/cypher_main_visitor.cpp b/tests/unit/cypher_main_visitor.cpp index ce237e40e..244e3ed6d 100644 --- a/tests/unit/cypher_main_visitor.cpp +++ b/tests/unit/cypher_main_visitor.cpp @@ -1632,8 +1632,8 @@ TYPED_TEST(CypherMainVisitorTest, MatchVariableLambdaSymbols) { TYPED_TEST(CypherMainVisitorTest, MatchWShortestReturn) { TypeParam ast_generator( - "MATCH ()-[r:type1|type2 *wShortest 10 (we, wn | 42) (e, n | true)]->() " - "RETURN r"); + "MATCH ()-[r:type1|type2 *wShortest 10 (we, wn | 42) total_weight " + "(e, n | true)]->() RETURN r"); auto *query = ast_generator.query_; ASSERT_TRUE(query->single_query_); auto *single_query = query->single_query_; @@ -1665,6 +1665,9 @@ TYPED_TEST(CypherMainVisitorTest, MatchWShortestReturn) { EXPECT_EQ(shortest->weight_lambda_.inner_node->name_, "wn"); EXPECT_TRUE(shortest->weight_lambda_.inner_node->user_declared_); CheckLiteral(ast_generator.context_, shortest->weight_lambda_.expression, 42); + ASSERT_TRUE(shortest->total_weight_); + EXPECT_EQ(shortest->total_weight_->name_, "total_weight"); + EXPECT_TRUE(shortest->total_weight_->user_declared_); } TYPED_TEST(CypherMainVisitorTest, MatchWShortestNoFilterReturn) { @@ -1699,6 +1702,8 @@ TYPED_TEST(CypherMainVisitorTest, MatchWShortestNoFilterReturn) { EXPECT_EQ(shortest->weight_lambda_.inner_node->name_, "wn"); EXPECT_TRUE(shortest->weight_lambda_.inner_node->user_declared_); CheckLiteral(ast_generator.context_, shortest->weight_lambda_.expression, 42); + ASSERT_TRUE(shortest->total_weight_); + EXPECT_FALSE(shortest->total_weight_->user_declared_); } diff --git a/tests/unit/query_semantic.cpp b/tests/unit/query_semantic.cpp index b0581bc53..09839c216 100644 --- a/tests/unit/query_semantic.cpp +++ b/tests/unit/query_semantic.cpp @@ -944,6 +944,7 @@ TEST_F(TestSymbolGenerator, MatchWShortestReturn) { shortest->weight_lambda_.inner_edge = IDENT("r"); shortest->weight_lambda_.inner_node = IDENT("n"); shortest->weight_lambda_.expression = r_weight; + shortest->total_weight_ = IDENT("total_weight"); } { shortest->filter_lambda_.inner_edge = IDENT("r"); @@ -954,8 +955,9 @@ TEST_F(TestSymbolGenerator, MatchWShortestReturn) { auto *query = QUERY(SINGLE_QUERY(MATCH(PATTERN(node_n, shortest, NODE("m"))), RETURN(ret_r, AS("r")))); query->Accept(symbol_generator); - // Symbols for pattern, `n`, `[r]`, (`r|`, `n|`)x2, `m` and `AS r`. - EXPECT_EQ(symbol_table.max_position(), 9); + // Symbols for pattern, `n`, `[r]`, `total_weight`, (`r|`, `n|`)x2, `m` and + // `AS r`. + EXPECT_EQ(symbol_table.max_position(), 10); EXPECT_EQ(symbol_table.at(*ret_r), symbol_table.at(*shortest->identifier_)); EXPECT_NE(symbol_table.at(*ret_r), symbol_table.at(*shortest->weight_lambda_.inner_edge));