Add CallProcedure clause to Cypher

Summary:
This adds support for basic invocation of the CALL clause of openCypher. The
accepted CIP has a lot more features that are available with the CALL clause.

https://github.com/opencypher/openCypher/blob/master/cip/1.accepted/CIP2015-06-24-call-procedures.adoc#appendix-procedure-naming-conventions

Reviewers: mferencevic, ipaljak, llugovic

Reviewed By: mferencevic, llugovic

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D2523
This commit is contained in:
Teon Banek 2019-10-03 14:49:11 +02:00
parent 4248b140d4
commit 32069c77a0
13 changed files with 355 additions and 8 deletions

View File

@ -1586,6 +1586,49 @@ cpp<#
(:serialize (:slk))
(:clone))
;; AST node for a `CALL procedure(args...) [YIELD ...]` clause.
(lcp:define-class call-procedure (clause)
  (;; Name of the invoked procedure.
   (procedure-name "std::string" :scope :public)
   ;; Expressions passed as the call arguments.
   (arguments "std::vector<Expression *>"
              :scope :public
              :slk-save #'slk-save-ast-vector
              :slk-load (slk-load-ast-vector "Expression"))
   ;; Names of the procedure result fields listed after YIELD.
   (result-fields "std::vector<std::string>" :scope :public)
   ;; Identifiers that the yielded fields are bound to.
   (result-identifiers "std::vector<Identifier *>"
                       :scope :public
                       :slk-save #'slk-save-ast-vector
                       :slk-load (slk-load-ast-vector "Identifier")))
  (:public
    #>cpp
    CallProcedure() = default;

    /**
     * Walks this clause with a hierarchical visitor.
     *
     * When `PreVisit` returns true, the arguments are visited first and the
     * result identifiers second. The walk over the children stops at the
     * first child whose `Accept` returns false.
     *
     * @return whatever `visitor.PostVisit(*this)` returns; `PostVisit` is
     *         always invoked, regardless of how the children were handled.
     */
    bool Accept(HierarchicalTreeVisitor &visitor) override {
      if (visitor.PreVisit(*this)) {
        // A single flag guards both loops, so a `false` from any argument
        // also prevents the result identifiers from being visited.
        bool keep_going = true;
        for (auto it = arguments_.begin();
             keep_going && it != arguments_.end(); ++it) {
          keep_going = (*it)->Accept(visitor);
        }
        for (auto it = result_identifiers_.begin();
             keep_going && it != result_identifiers_.end(); ++it) {
          keep_going = (*it)->Accept(visitor);
        }
      }
      return visitor.PostVisit(*this);
    }
    cpp<#)
  (:private
    #>cpp
    friend class AstStorage;
    cpp<#)
  (:serialize (:slk))
  (:clone))
(lcp:define-class match (clause)
((patterns "std::vector<Pattern *>"
:scope :public

View File

@ -20,6 +20,7 @@ class Extract;
class All;
class Single;
class ParameterLookup;
class CallProcedure;
class Create;
class Match;
class Return;
@ -78,10 +79,10 @@ using TreeCompositeVisitor = ::utils::CompositeVisitor<
GreaterEqualOperator, InListOperator, SubscriptOperator,
ListSlicingOperator, IfOperator, UnaryPlusOperator, UnaryMinusOperator,
IsNullOperator, ListLiteral, MapLiteral, PropertyLookup, LabelsTest,
Aggregation, Function, Reduce, Coalesce, Extract, All, Single, Create,
Match, Return, With, Pattern, NodeAtom, EdgeAtom, Delete, Where,
SetProperty, SetProperties, SetLabels, RemoveProperty, RemoveLabels, Merge,
Unwind, RegexMatch>;
Aggregation, Function, Reduce, Coalesce, Extract, All, Single,
CallProcedure, Create, Match, Return, With, Pattern, NodeAtom, EdgeAtom,
Delete, Where, SetProperty, SetProperties, SetLabels, RemoveProperty,
RemoveLabels, Merge, Unwind, RegexMatch>;
using TreeLeafVisitor =
::utils::LeafVisitor<Identifier, PrimitiveLiteral, ParameterLookup>;

View File

@ -217,10 +217,16 @@ antlrcpp::Any CypherMainVisitor::visitSingleQuery(
bool has_update = false;
bool has_return = false;
bool has_optional_match = false;
bool has_call_procedure = false;
for (Clause *clause : single_query->clauses_) {
const auto &clause_type = clause->GetTypeInfo();
if (utils::IsSubtype(clause_type, Unwind::kType)) {
if (utils::IsSubtype(clause_type, CallProcedure::kType)) {
if (has_return) {
throw SemanticException("CALL can't be put after RETURN clause.");
}
has_call_procedure = true;
} else if (utils::IsSubtype(clause_type, Unwind::kType)) {
if (has_update || has_return) {
throw SemanticException(
"UNWIND can't be put after RETURN clause or after an update.");
@ -261,7 +267,9 @@ antlrcpp::Any CypherMainVisitor::visitSingleQuery(
DLOG(FATAL) << "Can't happen";
}
}
if (!has_update && !has_return) {
bool is_standalone_call_procedure =
has_call_procedure && single_query->clauses_.size() == 1U;
if (!has_update && !has_return && !is_standalone_call_procedure) {
throw SemanticException(
"Query should either create or update something, or return results!");
}
@ -314,6 +322,10 @@ antlrcpp::Any CypherMainVisitor::visitClause(
if (ctx->unwind()) {
return static_cast<Clause *>(ctx->unwind()->accept(this).as<Unwind *>());
}
if (ctx->callProcedure()) {
return static_cast<Clause *>(
ctx->callProcedure()->accept(this).as<CallProcedure *>());
}
// TODO: implement other clauses.
throw utils::NotYetImplemented("clause '{}'", ctx->getText());
return 0;
@ -337,6 +349,44 @@ antlrcpp::Any CypherMainVisitor::visitCreate(
return create;
}
/**
 * Builds a CallProcedure AST node out of a parsed `CALL` clause.
 *
 * The procedure name parts are joined with ".", every argument expression is
 * visited and stored, and, if a YIELD part is present, each yielded field is
 * recorded together with a freshly created Identifier named after its alias
 * (or after the field itself when no `AS alias` was given).
 *
 * @return CallProcedure*
 */
antlrcpp::Any CypherMainVisitor::visitCallProcedure(
    MemgraphCypher::CallProcedureContext *ctx) {
  auto *call_proc = storage_->Create<CallProcedure>();

  // Procedure name: visit each symbolic name and join them with dots.
  const auto name_parts = ctx->procedureName()->symbolicName();
  CHECK(!name_parts.empty());
  std::vector<std::string> procedure_subnames;
  procedure_subnames.reserve(name_parts.size());
  for (auto *name_part : name_parts) {
    procedure_subnames.emplace_back(
        name_part->accept(this).as<std::string>());
  }
  utils::Join(&call_proc->procedure_name_, procedure_subnames, ".");

  // Call arguments, in the order they were written.
  const auto argument_exprs = ctx->expression();
  call_proc->arguments_.reserve(argument_exprs.size());
  for (auto *argument : argument_exprs) {
    call_proc->arguments_.push_back(argument->accept(this));
  }

  // Optional YIELD part.
  // TODO: Standalone CallProcedure clause may omit YIELD only if the function
  // never returns anything.
  if (auto *yield_ctx = ctx->yieldProcedureResults()) {
    const auto yielded = yield_ctx->procedureResult();
    call_proc->result_fields_.reserve(yielded.size());
    call_proc->result_identifiers_.reserve(yielded.size());
    for (auto *result : yielded) {
      // Either `field` (one variable) or `field AS alias` (two variables).
      const auto variables = result->variable();
      CHECK(variables.size() == 1 || variables.size() == 2);
      call_proc->result_fields_.push_back(
          variables[0]->accept(this).as<std::string>());
      // The last variable names the identifier: the alias if there is one,
      // otherwise the field itself.
      auto *alias_ctx = variables.size() == 2 ? variables[1] : variables[0];
      std::string result_alias = alias_ctx->accept(this).as<std::string>();
      call_proc->result_identifiers_.push_back(
          storage_->Create<Identifier>(result_alias));
    }
  }
  return call_proc;
}
/**
* @return std::string
*/

View File

@ -214,6 +214,11 @@ class CypherMainVisitor : public antlropencypher::MemgraphCypherBaseVisitor {
*/
antlrcpp::Any visitCreate(MemgraphCypher::CreateContext *ctx) override;
/**
 * Visits a parsed `CALL` clause and produces the corresponding AST node.
 *
 * @return CallProcedure*
 */
antlrcpp::Any visitCallProcedure(MemgraphCypher::CallProcedureContext *ctx) override;
/**
* @return std::string
*/

View File

@ -73,6 +73,7 @@ clause : cypherMatch
| remove
| with
| cypherReturn
| callProcedure
;
cypherMatch : OPTIONAL? MATCH pattern where? ;
@ -107,6 +108,14 @@ with : WITH ( DISTINCT )? returnBody ( where )? ;
cypherReturn : RETURN ( DISTINCT )? returnBody ;
// CALL <name>( <comma separated expressions, possibly none> ) [ YIELD ... ]
// The parentheses are mandatory; the YIELD part is optional.
callProcedure : CALL procedureName '(' ( expression ( ',' expression )* )? ')' ( yieldProcedureResults )? ;
// One or more symbolic names separated by dots, e.g. `proc.with.dots`.
procedureName : symbolicName ( '.' symbolicName )* ;
// YIELD must be followed by at least one result.
yieldProcedureResults : YIELD ( procedureResult ( ',' procedureResult )* ) ;
// A result is either `field AS alias` or just `field`.
procedureResult : ( variable AS variable ) | variable ;
returnBody : returnItems ( order )? ( skip )? ( limit )? ;
returnItems : ( '*' ( ',' returnItem )* )
@ -312,6 +321,7 @@ cypherKeyword : ALL
| ASSERT
| BFS
| BY
| CALL
| CASE
| CONSTRAINT
| CONTAINS
@ -366,6 +376,7 @@ cypherKeyword : ALL
| WITH
| WSHORTEST
| XOR
| YIELD
;
symbolicName : UnescapedSymbolicName

View File

@ -77,6 +77,7 @@ ASCENDING : A S C E N D I N G ;
ASSERT : A S S E R T ;
BFS : B F S ;
BY : B Y ;
CALL : C A L L ;
CASE : C A S E ;
COALESCE : C O A L E S C E ;
CONSTRAINT : C O N S T R A I N T ;
@ -134,6 +135,7 @@ WHERE : W H E R E ;
WITH : W I T H ;
WSHORTEST : W S H O R T E S T ;
XOR : X O R ;
YIELD : Y I E L D ;
/* Double and single quoted string literals. */
StringLiteral : '"' ( ~[\\"] | EscapeSequence )* '"'

View File

@ -68,6 +68,10 @@ class PrivilegeExtractor : public QueryVisitor<void>,
AddPrivilege(AuthQuery::Privilege::CREATE);
return false;
}
// CALL clause: no privilege is recorded for it yet (see the TODO below).
// `false` is returned, like the neighbouring clause handlers do.
bool PreVisit(CallProcedure &) override {
// TODO: Corresponding privilege
return false;
}
bool PreVisit(Delete &) override {
AddPrivilege(AuthQuery::Privilege::DELETE);
return false;

View File

@ -156,6 +156,23 @@ bool SymbolGenerator::PostVisit(Create &) {
return true;
}
// Walks only the call arguments with this generator. Returning false makes
// CallProcedure::Accept skip its own traversal of the children, so the
// result identifiers are left for PostVisit(CallProcedure &) to handle.
bool SymbolGenerator::PreVisit(CallProcedure &call_proc) {
  auto &arguments = call_proc.arguments_;
  for (auto *argument : arguments) {
    argument->Accept(*this);
  }
  return false;
}
// Creates a new symbol for every result identifier of the call and maps the
// identifier to it.
// Throws RedeclareVariableError if a symbol with the same name already exists,
// which also covers the same name appearing twice among the results.
bool SymbolGenerator::PostVisit(CallProcedure &call_proc) {
  for (auto *result : call_proc.result_identifiers_) {
    const auto &name = result->name_;
    if (HasSymbol(name)) {
      throw RedeclareVariableError(name);
    }
    result->MapTo(CreateSymbol(name, true));
  }
  return true;
}
bool SymbolGenerator::PreVisit(Return &ret) {
scope_.in_return = true;
VisitReturnBody(ret.body_);

View File

@ -35,6 +35,8 @@ class SymbolGenerator : public HierarchicalTreeVisitor {
// Clauses
bool PreVisit(Create &) override;
bool PostVisit(Create &) override;
bool PreVisit(CallProcedure &) override;
bool PostVisit(CallProcedure &) override;
bool PreVisit(Return &) override;
bool PostVisit(Return &) override;
bool PreVisit(With &) override;

View File

@ -91,7 +91,7 @@ const trie::Trie kKeywords = {
"show", "stats",
"unique", "explain", "profile",
"storage", "index", "info", "exists", "assert", "constraint",
"node", "key", "dump", "database"};
"node", "key", "dump", "database", "call", "yield"};
// Unicode codepoints that are allowed at the start of the unescaped name.
const std::bitset<kBitsetSize> kUnescapedNameAllowedStarts(std::string(

View File

@ -210,7 +210,9 @@ class RuleBasedPlanner {
std::move(input_op), unwind->named_expression_->expression_,
symbol);
} else {
throw utils::NotYetImplemented("clause conversion to operator(s)");
throw utils::NotYetImplemented(
"clause '{}' conversion to operator(s)",
clause->GetTypeInfo().name);
}
}
}

View File

@ -2669,4 +2669,154 @@ TEST_P(CypherMainVisitorTest, DumpDatabase) {
ASSERT_TRUE(query);
}
// A dotted procedure name is kept as a single dot-separated string and a call
// without arguments or YIELD leaves all the other fields empty.
TEST_P(CypherMainVisitorTest, CallProcedureWithDotsInName) {
  auto &generator = *GetParam();
  auto *parsed = generator.ParseQuery("CALL proc.with.dots()");
  auto *query = dynamic_cast<CypherQuery *>(parsed);
  ASSERT_TRUE(query);
  ASSERT_TRUE(query->single_query_);
  const auto &clauses = query->single_query_->clauses_;
  ASSERT_EQ(clauses.size(), 1U);
  auto *call_proc = dynamic_cast<CallProcedure *>(clauses[0]);
  ASSERT_TRUE(call_proc);
  ASSERT_EQ(call_proc->procedure_name_, "proc.with.dots");
  ASSERT_TRUE(call_proc->arguments_.empty());
  ASSERT_TRUE(call_proc->result_fields_.empty());
  ASSERT_TRUE(call_proc->result_identifiers_.empty());
}
// A backtick-escaped procedure name may contain dashes; the stored name does
// not include the backticks.
TEST_P(CypherMainVisitorTest, CallProcedureWithDashesInName) {
  auto &generator = *GetParam();
  auto *parsed = generator.ParseQuery("CALL `proc-with-dashes`()");
  auto *query = dynamic_cast<CypherQuery *>(parsed);
  ASSERT_TRUE(query);
  ASSERT_TRUE(query->single_query_);
  const auto &clauses = query->single_query_->clauses_;
  ASSERT_EQ(clauses.size(), 1U);
  auto *call_proc = dynamic_cast<CallProcedure *>(clauses[0]);
  ASSERT_TRUE(call_proc);
  ASSERT_EQ(call_proc->procedure_name_, "proc-with-dashes");
  ASSERT_TRUE(call_proc->arguments_.empty());
  ASSERT_TRUE(call_proc->result_fields_.empty());
  ASSERT_TRUE(call_proc->result_identifiers_.empty());
}
// YIELD without aliases: every result identifier is named after its field.
TEST_P(CypherMainVisitorTest, CallProcedureWithYieldSomeFields) {
  auto &generator = *GetParam();
  auto *parsed = generator.ParseQuery(
      "CALL proc() YIELD fst, `field-with-dashes`, last_field");
  auto *query = dynamic_cast<CypherQuery *>(parsed);
  ASSERT_TRUE(query);
  ASSERT_TRUE(query->single_query_);
  const auto &clauses = query->single_query_->clauses_;
  ASSERT_EQ(clauses.size(), 1U);
  auto *call_proc = dynamic_cast<CallProcedure *>(clauses[0]);
  ASSERT_TRUE(call_proc);
  ASSERT_EQ(call_proc->procedure_name_, "proc");
  ASSERT_TRUE(call_proc->arguments_.empty());

  const auto &fields = call_proc->result_fields_;
  const auto &identifiers = call_proc->result_identifiers_;
  ASSERT_EQ(fields.size(), 3U);
  ASSERT_EQ(identifiers.size(), fields.size());

  // Collect the identifier names so they can be compared as a whole.
  std::vector<std::string> identifier_names;
  identifier_names.reserve(identifiers.size());
  for (const auto *identifier : identifiers) {
    ASSERT_TRUE(identifier->user_declared_);
    identifier_names.push_back(identifier->name_);
  }
  const std::vector<std::string> expected_names{"fst", "field-with-dashes",
                                                "last_field"};
  ASSERT_EQ(identifier_names, expected_names);
  ASSERT_EQ(identifier_names, fields);
}
// YIELD with aliases: fields keep the procedure's names while the result
// identifiers take the alias names.
TEST_P(CypherMainVisitorTest, CallProcedureWithYieldAliasedFields) {
  auto &generator = *GetParam();
  auto *parsed = generator.ParseQuery(
      "CALL proc() YIELD fst AS res1, snd AS `result-with-dashes`, "
      "thrd AS last_result");
  auto *query = dynamic_cast<CypherQuery *>(parsed);
  ASSERT_TRUE(query);
  ASSERT_TRUE(query->single_query_);
  const auto &clauses = query->single_query_->clauses_;
  ASSERT_EQ(clauses.size(), 1U);
  auto *call_proc = dynamic_cast<CallProcedure *>(clauses[0]);
  ASSERT_TRUE(call_proc);
  ASSERT_EQ(call_proc->procedure_name_, "proc");
  ASSERT_TRUE(call_proc->arguments_.empty());

  const auto &fields = call_proc->result_fields_;
  const auto &identifiers = call_proc->result_identifiers_;
  ASSERT_EQ(fields.size(), 3U);
  ASSERT_EQ(identifiers.size(), fields.size());

  // Collect the identifier names so they can be compared as a whole.
  std::vector<std::string> identifier_names;
  identifier_names.reserve(identifiers.size());
  for (const auto *identifier : identifiers) {
    ASSERT_TRUE(identifier->user_declared_);
    identifier_names.push_back(identifier->name_);
  }
  const std::vector<std::string> aliased_names{"res1", "result-with-dashes",
                                               "last_result"};
  ASSERT_EQ(identifier_names, aliased_names);
  const std::vector<std::string> field_names{"fst", "snd", "thrd"};
  ASSERT_EQ(fields, field_names);
}
// Literal arguments are stored in call order; without YIELD there are no
// result fields or identifiers.
TEST_P(CypherMainVisitorTest, CallProcedureWithArguments) {
  auto &generator = *GetParam();
  auto *parsed = generator.ParseQuery("CALL proc(0, 1, 2)");
  auto *query = dynamic_cast<CypherQuery *>(parsed);
  ASSERT_TRUE(query);
  ASSERT_TRUE(query->single_query_);
  const auto &clauses = query->single_query_->clauses_;
  ASSERT_EQ(clauses.size(), 1U);
  auto *call_proc = dynamic_cast<CallProcedure *>(clauses[0]);
  ASSERT_TRUE(call_proc);
  ASSERT_EQ(call_proc->procedure_name_, "proc");
  ASSERT_TRUE(call_proc->result_fields_.empty());
  ASSERT_EQ(call_proc->result_identifiers_.size(),
            call_proc->result_fields_.size());
  ASSERT_EQ(call_proc->arguments_.size(), 3U);
  // The i-th argument is expected to be the literal `i`.
  for (int64_t expected = 0; expected < 3; ++expected) {
    generator.CheckLiteral(call_proc->arguments_[expected], expected);
  }
}
// Queries using CALL that must be rejected, either by the grammar
// (SyntaxException) or by the clause-ordering checks (SemanticException).
TEST_P(CypherMainVisitorTest, IncorrectCallProcedure) {
auto &ast_generator = *GetParam();
// Unescaped dashes and dots are not allowed in procedure names, yielded
// fields or their aliases.
ASSERT_THROW(ast_generator.ParseQuery("CALL proc-with-dashes()"),
SyntaxException);
ASSERT_THROW(ast_generator.ParseQuery("CALL proc() yield field-with-dashes"),
SyntaxException);
ASSERT_THROW(ast_generator.ParseQuery("CALL proc() yield field.with.dots"),
SyntaxException);
ASSERT_THROW(
ast_generator.ParseQuery("CALL proc() yield res AS result-with-dashes"),
SyntaxException);
ASSERT_THROW(
ast_generator.ParseQuery("CALL proc() yield res AS result.with.dots"),
SyntaxException);
// CALL is not the only clause here and the query neither returns nor updates.
ASSERT_THROW(ast_generator.ParseQuery("WITH 42 AS x CALL not_standalone(x)"),
SemanticException);
// YIELD requires at least one result.
ASSERT_THROW(ast_generator.ParseQuery("CALL procedure() YIELD"),
SyntaxException);
// CALL cannot appear as an item of the RETURN list.
ASSERT_THROW(ast_generator.ParseQuery("RETURN 42, CALL procedure() YIELD"),
SyntaxException);
ASSERT_THROW(
ast_generator.ParseQuery("RETURN 42, CALL procedure() YIELD res"),
SyntaxException);
// CALL placed after a RETURN clause.
ASSERT_THROW(
ast_generator.ParseQuery("RETURN 42 AS x CALL procedure() YIELD res"),
SemanticException);
// TODO: Implement support for the following syntax. These are defined in
// Neo4j and accepted in openCypher CIP.
ASSERT_THROW(ast_generator.ParseQuery("CALL proc"), SyntaxException);
ASSERT_THROW(ast_generator.ParseQuery("CALL proc RETURN 42"),
SyntaxException);
ASSERT_THROW(ast_generator.ParseQuery("CALL proc() YIELD *"),
SyntaxException);
ASSERT_THROW(ast_generator.ParseQuery("CALL proc() YIELD * RETURN *"),
SyntaxException);
ASSERT_THROW(ast_generator.ParseQuery("CALL proc() YIELD res WHERE res > 42"),
SyntaxException);
ASSERT_THROW(
ast_generator.ParseQuery("CALL proc() YIELD res WHERE res > 42 RETURN *"),
SyntaxException);
}
} // namespace

View File

@ -1129,3 +1129,63 @@ TEST_F(TestSymbolGenerator, MatchUnion) {
auto symbol_table = query::MakeSymbolTable(query);
EXPECT_EQ(symbol_table.max_position(), 8);
}
// The call argument resolves to the symbol introduced by WITH, while the
// yielded identifier gets its own symbol that RETURN can refer to.
TEST_F(TestSymbolGenerator, CallProcedureYield) {
  // WITH 1 AS x CALL proc(x) YIELD x AS y RETURN x, y
  auto *call_proc = storage.Create<CallProcedure>();
  call_proc->procedure_name_ = "proc";
  auto *argument_x = IDENT("x");
  call_proc->arguments_.push_back(argument_x);
  call_proc->result_fields_.emplace_back("x");
  auto *yielded_y = IDENT("y");
  call_proc->result_identifiers_.push_back(yielded_y);
  auto *with_as_x = AS("x");
  auto *ret = RETURN("x", "y");
  auto query =
      QUERY(SINGLE_QUERY(WITH(LITERAL(1), with_as_x), call_proc, ret));
  auto symbol_table = query::MakeSymbolTable(query);
  EXPECT_EQ(symbol_table.max_position(), 4);

  const auto &sym_x = symbol_table.at(*with_as_x);
  const auto &sym_y = symbol_table.at(*yielded_y);
  EXPECT_EQ(symbol_table.at(*argument_x), sym_x);

  // The expressions inside RETURN are plain identifiers ...
  const auto &returned = ret->body_.named_expressions;
  auto *ret_x = dynamic_cast<Identifier *>(returned[0]->expression_);
  ASSERT_TRUE(ret_x);
  auto *ret_y = dynamic_cast<Identifier *>(returned[1]->expression_);
  ASSERT_TRUE(ret_y);
  // ... which resolve to the existing symbols, whereas the named expressions
  // themselves are mapped to different symbols.
  EXPECT_EQ(symbol_table.at(*ret_x), sym_x);
  EXPECT_EQ(symbol_table.at(*ret_y), sym_y);
  EXPECT_NE(symbol_table.at(*returned[0]), sym_x);
  EXPECT_NE(symbol_table.at(*returned[1]), sym_y);
}
// Yielding a field under a name that is already bound must be rejected.
TEST_F(TestSymbolGenerator, CallProcedureShadowingYield) {
  // WITH 1 AS x CALL proc() YIELD x RETURN 42 AS res
  auto *call_proc = storage.Create<CallProcedure>();
  call_proc->procedure_name_ = "proc";
  call_proc->result_fields_.emplace_back("x");
  call_proc->result_identifiers_.push_back(IDENT("x"));
  auto *single_query = SINGLE_QUERY(WITH(LITERAL(1), AS("x")), call_proc,
                                    RETURN(LITERAL(42), AS("res")));
  auto query = QUERY(single_query);
  EXPECT_THROW(query::MakeSymbolTable(query), SemanticException);
}
// Aliasing a yielded field to a name that is already bound must be rejected.
TEST_F(TestSymbolGenerator, CallProcedureShadowingYieldAlias) {
  // WITH 1 AS x CALL proc() YIELD y AS x RETURN 42 AS res
  auto *call_proc = storage.Create<CallProcedure>();
  call_proc->procedure_name_ = "proc";
  call_proc->result_fields_.emplace_back("y");
  call_proc->result_identifiers_.push_back(IDENT("x"));
  auto *single_query = SINGLE_QUERY(WITH(LITERAL(1), AS("x")), call_proc,
                                    RETURN(LITERAL(42), AS("res")));
  auto query = QUERY(single_query);
  EXPECT_THROW(query::MakeSymbolTable(query), SemanticException);
}
// An argument that refers to a name nothing has bound must be rejected.
TEST_F(TestSymbolGenerator, CallProcedureUnboundArgument) {
  // CALL proc(unbound)
  auto *call_proc = storage.Create<CallProcedure>();
  call_proc->procedure_name_ = "proc";
  auto *unbound_argument = IDENT("unbound");
  call_proc->arguments_.push_back(unbound_argument);
  auto query = QUERY(SINGLE_QUERY(call_proc));
  EXPECT_THROW(query::MakeSymbolTable(query), SemanticException);
}