From 32069c77a06998c59371288713afe84b9f1e3a83 Mon Sep 17 00:00:00 2001 From: Teon Banek <teon.banek@memgraph.io> Date: Thu, 3 Oct 2019 14:49:11 +0200 Subject: [PATCH] Add CallProcedure clause to Cypher Summary: This adds support for basic invocation to CALL clause of openCypher. The accepted CIP has a lot more features that are avaiable with CALL clause. https://github.com/opencypher/openCypher/blob/master/cip/1.accepted/CIP2015-06-24-call-procedures.adoc#appendix-procedure-naming-conventions Reviewers: mferencevic, ipaljak, llugovic Reviewed By: mferencevic, llugovic Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D2523 --- src/query/frontend/ast/ast.lcp | 43 +++++ src/query/frontend/ast/ast_visitor.hpp | 9 +- .../frontend/ast/cypher_main_visitor.cpp | 54 ++++++- .../frontend/ast/cypher_main_visitor.hpp | 5 + .../frontend/opencypher/grammar/Cypher.g4 | 11 ++ .../opencypher/grammar/CypherLexer.g4 | 2 + .../frontend/semantic/required_privileges.cpp | 4 + .../frontend/semantic/symbol_generator.cpp | 17 ++ .../frontend/semantic/symbol_generator.hpp | 2 + .../frontend/stripped_lexer_constants.hpp | 2 +- src/query/plan/rule_based_planner.hpp | 4 +- tests/unit/cypher_main_visitor.cpp | 150 ++++++++++++++++++ tests/unit/query_semantic.cpp | 60 +++++++ 13 files changed, 355 insertions(+), 8 deletions(-) diff --git a/src/query/frontend/ast/ast.lcp b/src/query/frontend/ast/ast.lcp index 56fef5fa5..2412b38f2 100644 --- a/src/query/frontend/ast/ast.lcp +++ b/src/query/frontend/ast/ast.lcp @@ -1586,6 +1586,49 @@ cpp<# (:serialize (:slk)) (:clone)) +(lcp:define-class call-procedure (clause) + ((procedure-name "std::string" :scope :public) + (arguments "std::vector<Expression *>" + :scope :public + :slk-save #'slk-save-ast-vector + :slk-load (slk-load-ast-vector "Expression")) + (result-fields "std::vector<std::string>" :scope :public) + (result-identifiers "std::vector<Identifier *>" + :scope :public + :slk-save #'slk-save-ast-vector + :slk-load (slk-load-ast-vector "Identifier"))) + (:public + #>cpp + CallProcedure() = default; + + bool Accept(HierarchicalTreeVisitor &visitor) override { + if (visitor.PreVisit(*this)) { + bool cont = true; + for (auto &arg : arguments_) { + if (!arg->Accept(visitor)) { + cont = false; + break; + } + } + if (cont) { + for (auto &ident : result_identifiers_) { + if (!ident->Accept(visitor)) { + cont = false; + break; + } + } + } + } + return visitor.PostVisit(*this); + } + cpp<#) + (:private + #>cpp + friend class AstStorage; + cpp<#) + (:serialize (:slk)) + (:clone)) + (lcp:define-class match (clause) ((patterns "std::vector<Pattern *>" :scope :public diff --git a/src/query/frontend/ast/ast_visitor.hpp b/src/query/frontend/ast/ast_visitor.hpp index aa2625df5..eba58818e 100644 --- a/src/query/frontend/ast/ast_visitor.hpp +++ b/src/query/frontend/ast/ast_visitor.hpp @@ -20,6 +20,7 @@ class Extract; class All; class Single; class ParameterLookup; +class CallProcedure; class Create; class Match; class Return; @@ -78,10 +79,10 @@ using TreeCompositeVisitor = ::utils::CompositeVisitor< GreaterEqualOperator, InListOperator, SubscriptOperator, ListSlicingOperator, IfOperator, UnaryPlusOperator, UnaryMinusOperator, IsNullOperator, ListLiteral, MapLiteral, PropertyLookup, LabelsTest, - Aggregation, Function, Reduce, Coalesce, Extract, All, Single, Create, - Match, Return, With, Pattern, NodeAtom, EdgeAtom, Delete, Where, - SetProperty, SetProperties, SetLabels, RemoveProperty, RemoveLabels, Merge, - Unwind, RegexMatch>; + Aggregation, Function, Reduce, Coalesce, Extract, All, Single, + CallProcedure, Create, Match, Return, With, Pattern, NodeAtom, EdgeAtom, + Delete, Where, SetProperty, SetProperties, SetLabels, RemoveProperty, + RemoveLabels, Merge, Unwind, RegexMatch>; using TreeLeafVisitor = ::utils::LeafVisitor<Identifier, PrimitiveLiteral, ParameterLookup>; diff --git a/src/query/frontend/ast/cypher_main_visitor.cpp b/src/query/frontend/ast/cypher_main_visitor.cpp index c69c26770..c9c8a3864 100644 --- a/src/query/frontend/ast/cypher_main_visitor.cpp +++ b/src/query/frontend/ast/cypher_main_visitor.cpp @@ -217,10 +217,16 @@ antlrcpp::Any CypherMainVisitor::visitSingleQuery( bool has_update = false; bool has_return = false; bool has_optional_match = false; + bool has_call_procedure = false; for (Clause *clause : single_query->clauses_) { const auto &clause_type = clause->GetTypeInfo(); - if (utils::IsSubtype(clause_type, Unwind::kType)) { + if (utils::IsSubtype(clause_type, CallProcedure::kType)) { + if (has_return) { + throw SemanticException("CALL can't be put after RETURN clause."); + } + has_call_procedure = true; + } else if (utils::IsSubtype(clause_type, Unwind::kType)) { if (has_update || has_return) { throw SemanticException( "UNWIND can't be put after RETURN clause or after an update."); @@ -261,7 +267,9 @@ antlrcpp::Any CypherMainVisitor::visitSingleQuery( DLOG(FATAL) << "Can't happen"; } } - if (!has_update && !has_return) { + bool is_standalone_call_procedure = + has_call_procedure && single_query->clauses_.size() == 1U; + if (!has_update && !has_return && !is_standalone_call_procedure) { throw SemanticException( "Query should either create or update something, or return results!"); } @@ -314,6 +322,10 @@ antlrcpp::Any CypherMainVisitor::visitClause( if (ctx->unwind()) { return static_cast<Clause *>(ctx->unwind()->accept(this).as<Unwind *>()); } + if (ctx->callProcedure()) { + return static_cast<Clause *>( + ctx->callProcedure()->accept(this).as<CallProcedure *>()); + } // TODO: implement other clauses. throw utils::NotYetImplemented("clause '{}'", ctx->getText()); return 0; @@ -337,6 +349,44 @@ antlrcpp::Any CypherMainVisitor::visitCreate( return create; } +antlrcpp::Any CypherMainVisitor::visitCallProcedure( + MemgraphCypher::CallProcedureContext *ctx) { + auto *call_proc = storage_->Create<CallProcedure>(); + CHECK(!ctx->procedureName()->symbolicName().empty()); + std::vector<std::string> procedure_subnames; + procedure_subnames.reserve(ctx->procedureName()->symbolicName().size()); + for (auto *subname : ctx->procedureName()->symbolicName()) { + procedure_subnames.emplace_back(subname->accept(this).as<std::string>()); + } + utils::Join(&call_proc->procedure_name_, procedure_subnames, "."); + call_proc->arguments_.reserve(ctx->expression().size()); + for (auto *expr : ctx->expression()) { + call_proc->arguments_.push_back(expr->accept(this)); + } + auto *yield_ctx = ctx->yieldProcedureResults(); + if (!yield_ctx) { + // TODO: Standalone CallProcedure clause may omit YIELD only if the function + // never returns anything. + return call_proc; + } + call_proc->result_fields_.reserve(yield_ctx->procedureResult().size()); + call_proc->result_identifiers_.reserve(yield_ctx->procedureResult().size()); + for (auto *result : yield_ctx->procedureResult()) { + CHECK(result->variable().size() == 1 || result->variable().size() == 2); + call_proc->result_fields_.push_back( + result->variable()[0]->accept(this).as<std::string>()); + std::string result_alias; + if (result->variable().size() == 2) { + result_alias = result->variable()[1]->accept(this).as<std::string>(); + } else { + result_alias = result->variable()[0]->accept(this).as<std::string>(); + } + call_proc->result_identifiers_.push_back( + storage_->Create<Identifier>(result_alias)); + } + return call_proc; +} + /** * @return std::string */ diff --git a/src/query/frontend/ast/cypher_main_visitor.hpp b/src/query/frontend/ast/cypher_main_visitor.hpp index 24842f519..dc12edeaf 100644 --- a/src/query/frontend/ast/cypher_main_visitor.hpp +++ b/src/query/frontend/ast/cypher_main_visitor.hpp @@ -214,6 +214,11 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor { */ antlrcpp::Any visitCreate(MemgraphCypher::CreateContext *ctx) override; + /** + * @return CallProcedure* + */ + antlrcpp::Any visitCallProcedure(MemgraphCypher::CallProcedureContext *ctx) override; + /** * @return std::string */ diff --git a/src/query/frontend/opencypher/grammar/Cypher.g4 b/src/query/frontend/opencypher/grammar/Cypher.g4 index 36f089628..d34cdf5d5 100644 --- a/src/query/frontend/opencypher/grammar/Cypher.g4 +++ b/src/query/frontend/opencypher/grammar/Cypher.g4 @@ -73,6 +73,7 @@ clause : cypherMatch | remove | with | cypherReturn + | callProcedure ; cypherMatch : OPTIONAL? MATCH pattern where? ; @@ -107,6 +108,14 @@ with : WITH ( DISTINCT )? returnBody ( where )? ; cypherReturn : RETURN ( DISTINCT )? returnBody ; +callProcedure : CALL procedureName '(' ( expression ( ',' expression )* )? ')' ( yieldProcedureResults )? ; + +procedureName : symbolicName ( '.' symbolicName )* ; + +yieldProcedureResults : YIELD ( procedureResult ( ',' procedureResult )* ) ; + +procedureResult : ( variable AS variable ) | variable ; + returnBody : returnItems ( order )? ( skip )? ( limit )? ; returnItems : ( '*' ( ',' returnItem )* ) @@ -312,6 +321,7 @@ cypherKeyword : ALL | ASSERT | BFS | BY + | CALL | CASE | CONSTRAINT | CONTAINS @@ -366,6 +376,7 @@ cypherKeyword : ALL | WITH | WSHORTEST | XOR + | YIELD ; symbolicName : UnescapedSymbolicName diff --git a/src/query/frontend/opencypher/grammar/CypherLexer.g4 b/src/query/frontend/opencypher/grammar/CypherLexer.g4 index 214326b88..5dd5a3017 100644 --- a/src/query/frontend/opencypher/grammar/CypherLexer.g4 +++ b/src/query/frontend/opencypher/grammar/CypherLexer.g4 @@ -77,6 +77,7 @@ ASCENDING : A S C E N D I N G ; ASSERT : A S S E R T ; BFS : B F S ; BY : B Y ; +CALL : C A L L ; CASE : C A S E ; COALESCE : C O A L E S C E ; CONSTRAINT : C O N S T R A I N T ; @@ -134,6 +135,7 @@ WHERE : W H E R E ; WITH : W I T H ; WSHORTEST : W S H O R T E S T ; XOR : X O R ; +YIELD : Y I E L D ; /* Double and single quoted string literals. */ StringLiteral : '"' ( ~[\\"] | EscapeSequence )* '"' diff --git a/src/query/frontend/semantic/required_privileges.cpp b/src/query/frontend/semantic/required_privileges.cpp index e090e697d..1f5097095 100644 --- a/src/query/frontend/semantic/required_privileges.cpp +++ b/src/query/frontend/semantic/required_privileges.cpp @@ -68,6 +68,10 @@ class PrivilegeExtractor : public QueryVisitor<void>, AddPrivilege(AuthQuery::Privilege::CREATE); return false; } + bool PreVisit(CallProcedure &) override { + // TODO: Corresponding privilege + return false; + } bool PreVisit(Delete &) override { AddPrivilege(AuthQuery::Privilege::DELETE); return false; diff --git a/src/query/frontend/semantic/symbol_generator.cpp b/src/query/frontend/semantic/symbol_generator.cpp index e604350dd..10a082b5d 100644 --- a/src/query/frontend/semantic/symbol_generator.cpp +++ b/src/query/frontend/semantic/symbol_generator.cpp @@ -156,6 +156,23 @@ bool SymbolGenerator::PostVisit(Create &) { return true; } +bool SymbolGenerator::PreVisit(CallProcedure &call_proc) { + for (auto *expr : call_proc.arguments_) { + expr->Accept(*this); + } + return false; +} + +bool SymbolGenerator::PostVisit(CallProcedure &call_proc) { + for (auto *ident : call_proc.result_identifiers_) { + if (HasSymbol(ident->name_)) { + throw RedeclareVariableError(ident->name_); + } + ident->MapTo(CreateSymbol(ident->name_, true)); + } + return true; +} + bool SymbolGenerator::PreVisit(Return &ret) { scope_.in_return = true; VisitReturnBody(ret.body_); diff --git a/src/query/frontend/semantic/symbol_generator.hpp b/src/query/frontend/semantic/symbol_generator.hpp index 6932e09b1..c9718c749 100644 --- a/src/query/frontend/semantic/symbol_generator.hpp +++ b/src/query/frontend/semantic/symbol_generator.hpp @@ -35,6 +35,8 @@ class SymbolGenerator : public HierarchicalTreeVisitor { // Clauses bool PreVisit(Create &) override; bool PostVisit(Create &) override; + bool PreVisit(CallProcedure &) override; + bool PostVisit(CallProcedure &) override; bool PreVisit(Return &) override; bool PostVisit(Return &) override; bool PreVisit(With &) override; diff --git a/src/query/frontend/stripped_lexer_constants.hpp b/src/query/frontend/stripped_lexer_constants.hpp index 7d99ca375..d30520e41 100644 --- a/src/query/frontend/stripped_lexer_constants.hpp +++ b/src/query/frontend/stripped_lexer_constants.hpp @@ -91,7 +91,7 @@ const trie::Trie kKeywords = { "show", "stats", "unique", "explain", "profile", "storage", "index", "info", "exists", "assert", "constraint", - "node", "key", "dump", "database"}; + "node", "key", "dump", "database", "call", "yield"}; // Unicode codepoints that are allowed at the start of the unescaped name. const std::bitset<kBitsetSize> kUnescapedNameAllowedStarts(std::string( diff --git a/src/query/plan/rule_based_planner.hpp b/src/query/plan/rule_based_planner.hpp index 59bc895f4..3fc82eb63 100644 --- a/src/query/plan/rule_based_planner.hpp +++ b/src/query/plan/rule_based_planner.hpp @@ -210,7 +210,9 @@ class RuleBasedPlanner { std::move(input_op), unwind->named_expression_->expression_, symbol); } else { - throw utils::NotYetImplemented("clause conversion to operator(s)"); + throw utils::NotYetImplemented( + "clause '{}' conversion to operator(s)", + clause->GetTypeInfo().name); } } } diff --git a/tests/unit/cypher_main_visitor.cpp b/tests/unit/cypher_main_visitor.cpp index 34e8e474a..fae89c47a 100644 --- a/tests/unit/cypher_main_visitor.cpp +++ b/tests/unit/cypher_main_visitor.cpp @@ -2669,4 +2669,154 @@ TEST_P(CypherMainVisitorTest, DumpDatabase) { ASSERT_TRUE(query); } +TEST_P(CypherMainVisitorTest, CallProcedureWithDotsInName) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>( + ast_generator.ParseQuery("CALL proc.with.dots()")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *call_proc = dynamic_cast<CallProcedure *>(single_query->clauses_[0]); + ASSERT_TRUE(call_proc); + ASSERT_EQ(call_proc->procedure_name_, "proc.with.dots"); + ASSERT_TRUE(call_proc->arguments_.empty()); + ASSERT_TRUE(call_proc->result_fields_.empty()); + ASSERT_TRUE(call_proc->result_identifiers_.empty()); +} + +TEST_P(CypherMainVisitorTest, CallProcedureWithDashesInName) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>( + ast_generator.ParseQuery("CALL `proc-with-dashes`()")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *call_proc = dynamic_cast<CallProcedure *>(single_query->clauses_[0]); + ASSERT_TRUE(call_proc); + ASSERT_EQ(call_proc->procedure_name_, "proc-with-dashes"); + ASSERT_TRUE(call_proc->arguments_.empty()); + ASSERT_TRUE(call_proc->result_fields_.empty()); + ASSERT_TRUE(call_proc->result_identifiers_.empty()); +} + +TEST_P(CypherMainVisitorTest, CallProcedureWithYieldSomeFields) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery( + "CALL proc() YIELD fst, `field-with-dashes`, last_field")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *call_proc = dynamic_cast<CallProcedure *>(single_query->clauses_[0]); + ASSERT_TRUE(call_proc); + ASSERT_EQ(call_proc->procedure_name_, "proc"); + ASSERT_TRUE(call_proc->arguments_.empty()); + ASSERT_EQ(call_proc->result_fields_.size(), 3U); + ASSERT_EQ(call_proc->result_identifiers_.size(), + call_proc->result_fields_.size()); + std::vector<std::string> identifier_names; + identifier_names.reserve(call_proc->result_identifiers_.size()); + for (const auto *identifier : call_proc->result_identifiers_) { + ASSERT_TRUE(identifier->user_declared_); + identifier_names.push_back(identifier->name_); + } + std::vector<std::string> expected_names{"fst", "field-with-dashes", + "last_field"}; + ASSERT_EQ(identifier_names, expected_names); + ASSERT_EQ(identifier_names, call_proc->result_fields_); +} + +TEST_P(CypherMainVisitorTest, CallProcedureWithYieldAliasedFields) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>( + ast_generator.ParseQuery("CALL proc() YIELD fst AS res1, snd AS " + "`result-with-dashes`, thrd AS last_result")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *call_proc = dynamic_cast<CallProcedure *>(single_query->clauses_[0]); + ASSERT_TRUE(call_proc); + ASSERT_EQ(call_proc->procedure_name_, "proc"); + ASSERT_TRUE(call_proc->arguments_.empty()); + ASSERT_EQ(call_proc->result_fields_.size(), 3U); + ASSERT_EQ(call_proc->result_identifiers_.size(), + call_proc->result_fields_.size()); + std::vector<std::string> identifier_names; + identifier_names.reserve(call_proc->result_identifiers_.size()); + for (const auto *identifier : call_proc->result_identifiers_) { + ASSERT_TRUE(identifier->user_declared_); + identifier_names.push_back(identifier->name_); + } + std::vector<std::string> aliased_names{"res1", "result-with-dashes", + "last_result"}; + ASSERT_EQ(identifier_names, aliased_names); + std::vector<std::string> field_names{"fst", "snd", "thrd"}; + ASSERT_EQ(call_proc->result_fields_, field_names); +} + +TEST_P(CypherMainVisitorTest, CallProcedureWithArguments) { + auto &ast_generator = *GetParam(); + auto *query = dynamic_cast<CypherQuery *>( + ast_generator.ParseQuery("CALL proc(0, 1, 2)")); + ASSERT_TRUE(query); + ASSERT_TRUE(query->single_query_); + auto *single_query = query->single_query_; + ASSERT_EQ(single_query->clauses_.size(), 1U); + auto *call_proc = dynamic_cast<CallProcedure *>(single_query->clauses_[0]); + ASSERT_TRUE(call_proc); + ASSERT_EQ(call_proc->procedure_name_, "proc"); + ASSERT_TRUE(call_proc->result_fields_.empty()); + ASSERT_EQ(call_proc->result_identifiers_.size(), + call_proc->result_fields_.size()); + ASSERT_EQ(call_proc->arguments_.size(), 3U); + for (int64_t i = 0; i < 3; ++i) { + ast_generator.CheckLiteral(call_proc->arguments_[i], i); + } +} + +TEST_P(CypherMainVisitorTest, IncorrectCallProcedure) { + auto &ast_generator = *GetParam(); + ASSERT_THROW(ast_generator.ParseQuery("CALL proc-with-dashes()"), + SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("CALL proc() yield field-with-dashes"), + SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("CALL proc() yield field.with.dots"), + SyntaxException); + ASSERT_THROW( + ast_generator.ParseQuery("CALL proc() yield res AS result-with-dashes"), + SyntaxException); + ASSERT_THROW( + ast_generator.ParseQuery("CALL proc() yield res AS result.with.dots"), + SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("WITH 42 AS x CALL not_standalone(x)"), + SemanticException); + ASSERT_THROW(ast_generator.ParseQuery("CALL procedure() YIELD"), + SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("RETURN 42, CALL procedure() YIELD"), + SyntaxException); + ASSERT_THROW( + ast_generator.ParseQuery("RETURN 42, CALL procedure() YIELD res"), + SyntaxException); + ASSERT_THROW( + ast_generator.ParseQuery("RETURN 42 AS x CALL procedure() YIELD res"), + SemanticException); + // TODO: Implement support for the following syntax. These are defined in + // Neo4j and accepted in openCypher CIP. + ASSERT_THROW(ast_generator.ParseQuery("CALL proc"), SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("CALL proc RETURN 42"), + SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("CALL proc() YIELD *"), + SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("CALL proc() YIELD * RETURN *"), + SyntaxException); + ASSERT_THROW(ast_generator.ParseQuery("CALL proc() YIELD res WHERE res > 42"), + SyntaxException); + ASSERT_THROW( + ast_generator.ParseQuery("CALL proc() YIELD res WHERE res > 42 RETURN *"), + SyntaxException); +} + } // namespace diff --git a/tests/unit/query_semantic.cpp b/tests/unit/query_semantic.cpp index 87abd69d7..42b837024 100644 --- a/tests/unit/query_semantic.cpp +++ b/tests/unit/query_semantic.cpp @@ -1129,3 +1129,63 @@ TEST_F(TestSymbolGenerator, MatchUnion) { auto symbol_table = query::MakeSymbolTable(query); EXPECT_EQ(symbol_table.max_position(), 8); } + +TEST_F(TestSymbolGenerator, CallProcedureYield) { + // WITH 1 AS x CALL proc(x) YIELD x AS y RETURN x, y + auto call = storage.Create<CallProcedure>(); + call->procedure_name_ = "proc"; + auto *arg_x = IDENT("x"); + call->arguments_.push_back(arg_x); + call->result_fields_.emplace_back("x"); + call->result_identifiers_.push_back(IDENT("y")); + auto *as_x = AS("x"); + auto *ret = RETURN("x", "y"); + auto query = QUERY(SINGLE_QUERY(WITH(LITERAL(1), as_x), call, ret)); + auto symbol_table = query::MakeSymbolTable(query); + EXPECT_EQ(symbol_table.max_position(), 4); + const auto &sym_x = symbol_table.at(*as_x); + const auto &sym_y = symbol_table.at(*call->result_identifiers_.back()); + EXPECT_EQ(symbol_table.at(*arg_x), sym_x); + auto *ret_x = + dynamic_cast<Identifier *>(ret->body_.named_expressions[0]->expression_); + ASSERT_TRUE(ret_x); + auto *ret_y = + dynamic_cast<Identifier *>(ret->body_.named_expressions[1]->expression_); + ASSERT_TRUE(ret_y); + EXPECT_EQ(symbol_table.at(*ret_x), sym_x); + EXPECT_EQ(symbol_table.at(*ret_y), sym_y); + EXPECT_NE(symbol_table.at(*ret->body_.named_expressions[0]), sym_x); + EXPECT_NE(symbol_table.at(*ret->body_.named_expressions[1]), sym_y); +} + +TEST_F(TestSymbolGenerator, CallProcedureShadowingYield) { + // WITH 1 AS x CALL proc() YIELD x RETURN 42 AS res + auto call = storage.Create<CallProcedure>(); + call->procedure_name_ = "proc"; + call->result_fields_.emplace_back("x"); + call->result_identifiers_.push_back(IDENT("x")); + auto query = QUERY(SINGLE_QUERY(WITH(LITERAL(1), AS("x")), call, + RETURN(LITERAL(42), AS("res")))); + EXPECT_THROW(query::MakeSymbolTable(query), SemanticException); +} + +TEST_F(TestSymbolGenerator, CallProcedureShadowingYieldAlias) { + // WITH 1 AS x CALL proc() YIELD y AS x RETURN 42 AS res + auto call = storage.Create<CallProcedure>(); + call->procedure_name_ = "proc"; + call->result_fields_.emplace_back("y"); + call->result_identifiers_.push_back(IDENT("x")); + auto query = QUERY(SINGLE_QUERY(WITH(LITERAL(1), AS("x")), call, + RETURN(LITERAL(42), AS("res")))); + EXPECT_THROW(query::MakeSymbolTable(query), SemanticException); +} + +TEST_F(TestSymbolGenerator, CallProcedureUnboundArgument) { + // CALL proc(unbound) + auto call = storage.Create<CallProcedure>(); + call->procedure_name_ = "proc"; + call->arguments_.push_back(IDENT("unbound")); + auto query = QUERY(SINGLE_QUERY(call)); + EXPECT_THROW(query::MakeSymbolTable(query), SemanticException); +} +