Handle write procedures in queries and update docs of properties iterator (#238)

This commit is contained in:
János Benjamin Antal 2021-09-30 15:11:51 +02:00 committed by Antonio Andelic
parent be9ed7e879
commit 721eefe263
9 changed files with 444 additions and 67 deletions

View File

@ -596,6 +596,7 @@ enum mgp_error mgp_vertex_get_property(struct mgp_vertex *v, const char *propert
struct mgp_value **result); struct mgp_value **result);
/// Start iterating over properties stored in the given vertex. /// Start iterating over properties stored in the given vertex.
/// The properties of the vertex are copied when the iterator is created, therefore later changes won't affect them.
/// The resulting mgp_properties_iterator needs to be deallocated with /// The resulting mgp_properties_iterator needs to be deallocated with
/// mgp_properties_iterator_destroy. /// mgp_properties_iterator_destroy.
/// Return MGP_ERROR_UNABLE_TO_ALLOCATE if unable to allocate a mgp_properties_iterator. /// Return MGP_ERROR_UNABLE_TO_ALLOCATE if unable to allocate a mgp_properties_iterator.
@ -604,16 +605,18 @@ enum mgp_error mgp_vertex_iter_properties(struct mgp_vertex *v, struct mgp_memor
struct mgp_properties_iterator **result); struct mgp_properties_iterator **result);
/// Start iterating over inbound edges of the given vertex. /// Start iterating over inbound edges of the given vertex.
/// The resulting mgp_edges_iterator needs to be deallocated with /// The connection information of the vertex is copied when the iterator is created, therefore later creation or
/// mgp_edges_iterator_destroy. /// deletion of edges won't affect the iterated edges, however the property changes on the edges will be visible.
/// The resulting mgp_edges_iterator needs to be deallocated with mgp_edges_iterator_destroy.
/// Return MGP_ERROR_UNABLE_TO_ALLOCATE if unable to allocate a mgp_edges_iterator. /// Return MGP_ERROR_UNABLE_TO_ALLOCATE if unable to allocate a mgp_edges_iterator.
/// Return MGP_ERROR_DELETED_OBJECT if `v` has been deleted. /// Return MGP_ERROR_DELETED_OBJECT if `v` has been deleted.
enum mgp_error mgp_vertex_iter_in_edges(struct mgp_vertex *v, struct mgp_memory *memory, enum mgp_error mgp_vertex_iter_in_edges(struct mgp_vertex *v, struct mgp_memory *memory,
struct mgp_edges_iterator **result); struct mgp_edges_iterator **result);
/// Start iterating over outbound edges of the given vertex. /// Start iterating over outbound edges of the given vertex.
/// The resulting mgp_edges_iterator needs to be deallocated with /// The connection information of the vertex is copied when the iterator is created, therefore later creation or
/// mgp_edges_iterator_destroy. /// deletion of edges won't affect the iterated edges, however the property changes on the edges will be visible.
/// The resulting mgp_edges_iterator needs to be deallocated with mgp_edges_iterator_destroy.
/// Return MGP_ERROR_UNABLE_TO_ALLOCATE if unable to allocate a mgp_edges_iterator. /// Return MGP_ERROR_UNABLE_TO_ALLOCATE if unable to allocate a mgp_edges_iterator.
/// Return MGP_ERROR_DELETED_OBJECT if `v` has been deleted. /// Return MGP_ERROR_DELETED_OBJECT if `v` has been deleted.
enum mgp_error mgp_vertex_iter_out_edges(struct mgp_vertex *v, struct mgp_memory *memory, enum mgp_error mgp_vertex_iter_out_edges(struct mgp_vertex *v, struct mgp_memory *memory,
@ -693,6 +696,7 @@ enum mgp_error mgp_edge_get_property(struct mgp_edge *e, const char *property_na
enum mgp_error mgp_edge_set_property(struct mgp_edge *e, const char *property_name, struct mgp_value *property_value); enum mgp_error mgp_edge_set_property(struct mgp_edge *e, const char *property_name, struct mgp_value *property_value);
/// Start iterating over properties stored in the given edge. /// Start iterating over properties stored in the given edge.
/// The properties of the edge are copied when the iterator is created, therefore later changes won't affect them.
/// Resulting mgp_properties_iterator needs to be deallocated with /// Resulting mgp_properties_iterator needs to be deallocated with
/// mgp_properties_iterator_destroy. /// mgp_properties_iterator_destroy.
/// Return MGP_ERROR_UNABLE_TO_ALLOCATE if unable to allocate a mgp_properties_iterator. /// Return MGP_ERROR_UNABLE_TO_ALLOCATE if unable to allocate a mgp_properties_iterator.

View File

@ -192,6 +192,9 @@ class Properties:
""" """
Iterate over the properties. Iterate over the properties.
Doesnt return a dynamic view of the properties, but copies the
current properties.
Raise InvalidContextError. Raise InvalidContextError.
Raise UnableToAllocateError if unable to allocate an iterator. Raise UnableToAllocateError if unable to allocate an iterator.
Raise DeletedObjectError if the object has been deleted. Raise DeletedObjectError if the object has been deleted.
@ -210,6 +213,9 @@ class Properties:
""" """
Iterate over property names. Iterate over property names.
Doesnt return a dynamic view of the property names, but copies the
name of the current properties.
Raise InvalidContextError. Raise InvalidContextError.
Raise UnableToAllocateError if unable to allocate an iterator. Raise UnableToAllocateError if unable to allocate an iterator.
Raise DeletedObjectError if the object has been deleted. Raise DeletedObjectError if the object has been deleted.
@ -223,6 +229,9 @@ class Properties:
""" """
Iterate over property values. Iterate over property values.
Doesnt return a dynamic view of the property values, but copies the
value of the current properties.
Raise InvalidContextError. Raise InvalidContextError.
Raise UnableToAllocateError if unable to allocate an iterator. Raise UnableToAllocateError if unable to allocate an iterator.
Raise DeletedObjectError if the object has been deleted. Raise DeletedObjectError if the object has been deleted.
@ -543,6 +552,9 @@ class Vertex:
""" """
Iterate over inbound edges of the vertex. Iterate over inbound edges of the vertex.
Doesnt return a dynamic view of the edges, but copies the
current inbound edges.
Raise InvalidContextError. Raise InvalidContextError.
Raise UnableToAllocateError if unable to allocate an iterator. Raise UnableToAllocateError if unable to allocate an iterator.
Raise DeletedObjectError if `self` has been deleted. Raise DeletedObjectError if `self` has been deleted.
@ -562,6 +574,9 @@ class Vertex:
""" """
Iterate over outbound edges of the vertex. Iterate over outbound edges of the vertex.
Doesnt return a dynamic view of the edges, but copies the
current outbound edges.
Raise InvalidContextError. Raise InvalidContextError.
Raise UnableToAllocateError if unable to allocate an iterator. Raise UnableToAllocateError if unable to allocate an iterator.
Raise DeletedObjectError if `self` has been deleted. Raise DeletedObjectError if `self` has been deleted.

View File

@ -1704,7 +1704,8 @@ cpp<#
(memory-limit "Expression *" :initval "nullptr" :scope :public (memory-limit "Expression *" :initval "nullptr" :scope :public
:slk-save #'slk-save-ast-pointer :slk-save #'slk-save-ast-pointer
:slk-load (slk-load-ast-pointer "Expression")) :slk-load (slk-load-ast-pointer "Expression"))
(memory-scale "size_t" :initval "1024U" :scope :public)) (memory-scale "size_t" :initval "1024U" :scope :public)
(is_write :bool :scope :public))
(:public (:public
#>cpp #>cpp
CallProcedure() = default; CallProcedure() = default;

View File

@ -651,16 +651,31 @@ antlrcpp::Any CypherMainVisitor::visitSingleQuery(MemgraphCypher::SingleQueryCon
bool has_return = false; bool has_return = false;
bool has_optional_match = false; bool has_optional_match = false;
bool has_call_procedure = false; bool has_call_procedure = false;
bool calls_write_procedure = false;
bool has_any_update = false;
bool has_load_csv = false; bool has_load_csv = false;
auto check_write_procedure = [&calls_write_procedure](const std::string_view clause) {
if (calls_write_procedure) {
throw SemanticException(
"{} can't be put after calling a writeable procedure, only RETURN clause can be put after.", clause);
}
};
for (Clause *clause : single_query->clauses_) { for (Clause *clause : single_query->clauses_) {
const auto &clause_type = clause->GetTypeInfo(); const auto &clause_type = clause->GetTypeInfo();
if (utils::IsSubtype(clause_type, CallProcedure::kType)) { if (const auto *call_procedure = utils::Downcast<CallProcedure>(clause); call_procedure != nullptr) {
if (has_return) { if (has_return) {
throw SemanticException("CALL can't be put after RETURN clause."); throw SemanticException("CALL can't be put after RETURN clause.");
} }
check_write_procedure("CALL");
has_call_procedure = true; has_call_procedure = true;
if (call_procedure->is_write_) {
calls_write_procedure = true;
has_update = true;
}
} else if (utils::IsSubtype(clause_type, Unwind::kType)) { } else if (utils::IsSubtype(clause_type, Unwind::kType)) {
check_write_procedure("UNWIND");
if (has_update || has_return) { if (has_update || has_return) {
throw SemanticException("UNWIND can't be put after RETURN clause or after an update."); throw SemanticException("UNWIND can't be put after RETURN clause or after an update.");
} }
@ -668,6 +683,7 @@ antlrcpp::Any CypherMainVisitor::visitSingleQuery(MemgraphCypher::SingleQueryCon
if (has_load_csv) { if (has_load_csv) {
throw SemanticException("Can't have multiple LOAD CSV clauses in a single query."); throw SemanticException("Can't have multiple LOAD CSV clauses in a single query.");
} }
check_write_procedure("LOAD CSV");
if (has_return) { if (has_return) {
throw SemanticException("LOAD CSV can't be put after RETURN clause."); throw SemanticException("LOAD CSV can't be put after RETURN clause.");
} }
@ -681,6 +697,7 @@ antlrcpp::Any CypherMainVisitor::visitSingleQuery(MemgraphCypher::SingleQueryCon
} else if (has_optional_match) { } else if (has_optional_match) {
throw SemanticException("MATCH can't be put after OPTIONAL MATCH."); throw SemanticException("MATCH can't be put after OPTIONAL MATCH.");
} }
check_write_procedure("MATCH");
} else if (utils::IsSubtype(clause_type, Create::kType) || utils::IsSubtype(clause_type, Delete::kType) || } else if (utils::IsSubtype(clause_type, Create::kType) || utils::IsSubtype(clause_type, Delete::kType) ||
utils::IsSubtype(clause_type, SetProperty::kType) || utils::IsSubtype(clause_type, SetProperty::kType) ||
utils::IsSubtype(clause_type, SetProperties::kType) || utils::IsSubtype(clause_type, SetLabels::kType) || utils::IsSubtype(clause_type, SetProperties::kType) || utils::IsSubtype(clause_type, SetLabels::kType) ||
@ -689,7 +706,9 @@ antlrcpp::Any CypherMainVisitor::visitSingleQuery(MemgraphCypher::SingleQueryCon
if (has_return) { if (has_return) {
throw SemanticException("Update clause can't be used after RETURN."); throw SemanticException("Update clause can't be used after RETURN.");
} }
check_write_procedure("Update clause");
has_update = true; has_update = true;
has_any_update = true;
} else if (utils::IsSubtype(clause_type, Return::kType)) { } else if (utils::IsSubtype(clause_type, Return::kType)) {
if (has_return) { if (has_return) {
throw SemanticException("There can only be one RETURN in a clause."); throw SemanticException("There can only be one RETURN in a clause.");
@ -699,6 +718,7 @@ antlrcpp::Any CypherMainVisitor::visitSingleQuery(MemgraphCypher::SingleQueryCon
if (has_return) { if (has_return) {
throw SemanticException("RETURN can't be put before WITH."); throw SemanticException("RETURN can't be put before WITH.");
} }
check_write_procedure("WITH");
has_update = has_return = has_optional_match = false; has_update = has_return = has_optional_match = false;
} else { } else {
DLOG_FATAL("Can't happen"); DLOG_FATAL("Can't happen");
@ -709,6 +729,9 @@ antlrcpp::Any CypherMainVisitor::visitSingleQuery(MemgraphCypher::SingleQueryCon
throw SemanticException("Query should either create or update something, or return results!"); throw SemanticException("Query should either create or update something, or return results!");
} }
if (has_any_update && calls_write_procedure) {
throw SemanticException("Write procedures cannot be used in queries that contains any update clauses!");
}
// Construct unique names for anonymous identifiers; // Construct unique names for anonymous identifiers;
int id = 1; int id = 1;
for (auto **identifier : anonymous_identifiers) { for (auto **identifier : anonymous_identifiers) {
@ -809,13 +832,15 @@ antlrcpp::Any CypherMainVisitor::visitCallProcedure(MemgraphCypher::CallProcedur
call_proc->memory_scale_ = 1024U * 1024U; call_proc->memory_scale_ = 1024U * 1024U;
} }
const auto &maybe_found =
procedure::FindProcedure(procedure::gModuleRegistry, call_proc->procedure_name_, utils::NewDeleteResource());
if (!maybe_found) {
throw SemanticException("There is no procedure named '{}'.", call_proc->procedure_name_);
}
call_proc->is_write_ = maybe_found->second->is_write_procedure;
auto *yield_ctx = ctx->yieldProcedureResults(); auto *yield_ctx = ctx->yieldProcedureResults();
if (!yield_ctx) { if (!yield_ctx) {
const auto &maybe_found =
procedure::FindProcedure(procedure::gModuleRegistry, call_proc->procedure_name_, utils::NewDeleteResource());
if (!maybe_found) {
throw SemanticException("There is no procedure named '{}'.", call_proc->procedure_name_);
}
if (!maybe_found->second->results.empty()) { if (!maybe_found->second->results.empty()) {
throw SemanticException( throw SemanticException(
"CALL without YIELD may only be used on procedures which do not " "CALL without YIELD may only be used on procedures which do not "

View File

@ -3591,14 +3591,15 @@ UniqueCursorPtr OutputTableStream::MakeCursor(utils::MemoryResource *mem) const
CallProcedure::CallProcedure(std::shared_ptr<LogicalOperator> input, std::string name, std::vector<Expression *> args, CallProcedure::CallProcedure(std::shared_ptr<LogicalOperator> input, std::string name, std::vector<Expression *> args,
std::vector<std::string> fields, std::vector<Symbol> symbols, Expression *memory_limit, std::vector<std::string> fields, std::vector<Symbol> symbols, Expression *memory_limit,
size_t memory_scale) size_t memory_scale, bool is_write)
: input_(input ? input : std::make_shared<Once>()), : input_(input ? input : std::make_shared<Once>()),
procedure_name_(name), procedure_name_(name),
arguments_(args), arguments_(args),
result_fields_(fields), result_fields_(fields),
result_symbols_(symbols), result_symbols_(symbols),
memory_limit_(memory_limit), memory_limit_(memory_limit),
memory_scale_(memory_scale) {} memory_scale_(memory_scale),
is_write_(is_write) {}
ACCEPT_WITH_INPUT(CallProcedure); ACCEPT_WITH_INPUT(CallProcedure);
@ -3741,6 +3742,12 @@ class CallProcedureCursor : public Cursor {
throw QueryRuntimeException("There is no procedure named '{}'.", self_->procedure_name_); throw QueryRuntimeException("There is no procedure named '{}'.", self_->procedure_name_);
} }
const auto &[module, proc] = *maybe_found; const auto &[module, proc] = *maybe_found;
if (proc->is_write_procedure != self_->is_write_) {
auto get_proc_type_str = [](bool is_write) { return is_write ? "write" : "read"; };
throw QueryRuntimeException("The procedure named '{}' was a {} procedure, but changed to be a {} procedure.",
self_->procedure_name_, get_proc_type_str(self_->is_write_),
get_proc_type_str(proc->is_write_procedure));
}
const auto graph_view = proc->is_write_procedure ? storage::View::NEW : storage::View::OLD; const auto graph_view = proc->is_write_procedure ? storage::View::NEW : storage::View::OLD;
ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor, ExpressionEvaluator evaluator(&frame, context.symbol_table, context.evaluation_context, context.db_accessor,
graph_view); graph_view);

View File

@ -2185,14 +2185,15 @@ at once. Instead, each call of the callback should return a single row of the ta
(memory-limit "Expression *" :initval "nullptr" :scope :public (memory-limit "Expression *" :initval "nullptr" :scope :public
:slk-save #'slk-save-ast-pointer :slk-save #'slk-save-ast-pointer
:slk-load (slk-load-ast-pointer "Expression")) :slk-load (slk-load-ast-pointer "Expression"))
(memory-scale "size_t" :initval "1024U" :scope :public)) (memory-scale "size_t" :initval "1024U" :scope :public)
(is_write :bool :scope :public))
(:public (:public
#>cpp #>cpp
CallProcedure() = default; CallProcedure() = default;
CallProcedure(std::shared_ptr<LogicalOperator> input, std::string name, CallProcedure(std::shared_ptr<LogicalOperator> input, std::string name,
std::vector<Expression *> arguments, std::vector<Expression *> arguments,
std::vector<std::string> fields, std::vector<Symbol> symbols, std::vector<std::string> fields, std::vector<Symbol> symbols,
Expression *memory_limit, size_t memory_scale); Expression *memory_limit, size_t memory_scale, bool is_write);
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override; bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override; UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override;

View File

@ -204,7 +204,7 @@ class RuleBasedPlanner {
// storage::View::NEW. // storage::View::NEW.
input_op = std::make_unique<plan::CallProcedure>( input_op = std::make_unique<plan::CallProcedure>(
std::move(input_op), call_proc->procedure_name_, call_proc->arguments_, call_proc->result_fields_, std::move(input_op), call_proc->procedure_name_, call_proc->arguments_, call_proc->result_fields_,
result_symbols, call_proc->memory_limit_, call_proc->memory_scale_); result_symbols, call_proc->memory_limit_, call_proc->memory_scale_, call_proc->is_write_);
} else if (auto *load_csv = utils::Downcast<query::LoadCsv>(clause)) { } else if (auto *load_csv = utils::Downcast<query::LoadCsv>(clause)) {
const auto &row_sym = context.symbol_table->at(*load_csv->row_var_); const auto &row_sym = context.symbol_table->at(*load_csv->row_var_);
context.bound_symbols.insert(row_sym); context.bound_symbols.insert(row_sym);

View File

@ -14,6 +14,8 @@
#include "utils/memory.hpp" #include "utils/memory.hpp"
#include "utils/rw_lock.hpp" #include "utils/rw_lock.hpp"
class CypherMainVisitorTest;
namespace query::procedure { namespace query::procedure {
class Module { class Module {
@ -52,6 +54,8 @@ class ModulePtr final {
/// Thread-safe registration of modules from libraries, uses utils::RWLock. /// Thread-safe registration of modules from libraries, uses utils::RWLock.
class ModuleRegistry final { class ModuleRegistry final {
friend CypherMainVisitorTest;
std::map<std::string, std::unique_ptr<Module>, std::less<>> modules_; std::map<std::string, std::unique_ptr<Module>, std::less<>> modules_;
mutable utils::RWLock lock_{utils::RWLock::Priority::WRITE}; mutable utils::RWLock lock_{utils::RWLock::Priority::WRITE};
std::unique_ptr<utils::MemoryResource> shared_{std::make_unique<utils::ResourceWithOutOfMemoryException>()}; std::unique_ptr<utils::MemoryResource> shared_{std::make_unique<utils::ResourceWithOutOfMemoryException>()};

View File

@ -24,12 +24,13 @@
#include "query/frontend/ast/cypher_main_visitor.hpp" #include "query/frontend/ast/cypher_main_visitor.hpp"
#include "query/frontend/opencypher/parser.hpp" #include "query/frontend/opencypher/parser.hpp"
#include "query/frontend/stripped.hpp" #include "query/frontend/stripped.hpp"
#include "query/procedure/cypher_types.hpp"
#include "query/procedure/mg_procedure_impl.hpp"
#include "query/procedure/module.hpp"
#include "query/typed_value.hpp" #include "query/typed_value.hpp"
#include "utils/string.hpp" #include "utils/string.hpp"
namespace {
using namespace query; using namespace query;
using namespace query::frontend; using namespace query::frontend;
using query::TypedValue; using query::TypedValue;
@ -168,7 +169,79 @@ class CachedAstGenerator : public Base {
AstStorage ast_storage_; AstStorage ast_storage_;
}; };
class CypherMainVisitorTest : public ::testing::TestWithParam<std::shared_ptr<Base>> {}; class MockModule : public procedure::Module {
public:
MockModule(){};
~MockModule() override{};
MockModule(const MockModule &) = delete;
MockModule(MockModule &&) = delete;
MockModule &operator=(const MockModule &) = delete;
MockModule &operator=(MockModule &&) = delete;
bool Close() override { return true; };
const std::map<std::string, mgp_proc, std::less<>> *Procedures() const override { return &procedures; }
const std::map<std::string, mgp_trans, std::less<>> *Transformations() const override { return &transformations; }
std::map<std::string, mgp_proc, std::less<>> procedures{};
std::map<std::string, mgp_trans, std::less<>> transformations{};
};
void DummyProcCallback(mgp_list * /*args*/, mgp_graph * /*graph*/, mgp_result * /*result*/, mgp_memory * /*memory*/){};
enum class ProcedureType { WRITE, READ };
std::string ToString(const ProcedureType type) { return type == ProcedureType::WRITE ? "write" : "read"; }
class CypherMainVisitorTest : public ::testing::TestWithParam<std::shared_ptr<Base>> {
public:
void SetUp() override {
{
auto mock_module_owner = std::make_unique<MockModule>();
mock_module = mock_module_owner.get();
procedure::gModuleRegistry.RegisterModule("mock_module", std::move(mock_module_owner));
}
{
auto mock_module_with_dots_in_name_owner = std::make_unique<MockModule>();
mock_module_with_dots_in_name = mock_module_with_dots_in_name_owner.get();
procedure::gModuleRegistry.RegisterModule("mock_module.with.dots.in.name",
std::move(mock_module_with_dots_in_name_owner));
}
}
void TearDown() override {
// To release any_type
procedure::gModuleRegistry.UnloadAllModules();
}
static void AddProc(MockModule &module, const char *name, const std::vector<std::string_view> &args,
const std::vector<std::string_view> &results, const ProcedureType type) {
utils::MemoryResource *memory = utils::NewDeleteResource();
const bool is_write = type == ProcedureType::WRITE;
mgp_proc proc(name, DummyProcCallback, memory, is_write);
for (const auto arg : args) {
proc.args.emplace_back(utils::pmr::string{arg, memory}, &any_type);
}
for (const auto result : results) {
proc.results.emplace(utils::pmr::string{result, memory}, std::make_pair(&any_type, false));
}
module.procedures.emplace(name, std::move(proc));
}
std::string CreateProcByType(const ProcedureType type, const std::vector<std::string_view> &args) {
const auto proc_name = std::string{"proc_"} + ToString(type);
SCOPED_TRACE(proc_name);
AddProc(*mock_module, proc_name.c_str(), {}, args, type);
return std::string{"mock_module."} + proc_name;
}
static const procedure::AnyType any_type;
MockModule *mock_module{nullptr};
MockModule *mock_module_with_dots_in_name{nullptr};
};
const procedure::AnyType CypherMainVisitorTest::any_type{};
std::shared_ptr<Base> gAstGeneratorTypes[] = { std::shared_ptr<Base> gAstGeneratorTypes[] = {
std::make_shared<AstGenerator>(), std::make_shared<AstGenerator>(),
@ -2556,15 +2629,18 @@ void CheckCallProcedureDefaultMemoryLimit(const TAst &ast, const CallProcedure &
} // namespace } // namespace
TEST_P(CypherMainVisitorTest, CallProcedureWithDotsInName) { TEST_P(CypherMainVisitorTest, CallProcedureWithDotsInName) {
AddProc(*mock_module_with_dots_in_name, "proc", {}, {"res"}, ProcedureType::WRITE);
auto &ast_generator = *GetParam(); auto &ast_generator = *GetParam();
auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("CALL proc.with.dots() YIELD res"));
auto *query =
dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("CALL mock_module.with.dots.in.name.proc() YIELD res"));
ASSERT_TRUE(query); ASSERT_TRUE(query);
ASSERT_TRUE(query->single_query_); ASSERT_TRUE(query->single_query_);
auto *single_query = query->single_query_; auto *single_query = query->single_query_;
ASSERT_EQ(single_query->clauses_.size(), 1U); ASSERT_EQ(single_query->clauses_.size(), 1U);
auto *call_proc = dynamic_cast<CallProcedure *>(single_query->clauses_[0]); auto *call_proc = dynamic_cast<CallProcedure *>(single_query->clauses_[0]);
ASSERT_TRUE(call_proc); ASSERT_TRUE(call_proc);
ASSERT_EQ(call_proc->procedure_name_, "proc.with.dots"); ASSERT_EQ(call_proc->procedure_name_, "mock_module.with.dots.in.name.proc");
ASSERT_TRUE(call_proc->arguments_.empty()); ASSERT_TRUE(call_proc->arguments_.empty());
std::vector<std::string> identifier_names; std::vector<std::string> identifier_names;
identifier_names.reserve(call_proc->result_identifiers_.size()); identifier_names.reserve(call_proc->result_identifiers_.size());
@ -2579,15 +2655,18 @@ TEST_P(CypherMainVisitorTest, CallProcedureWithDotsInName) {
} }
TEST_P(CypherMainVisitorTest, CallProcedureWithDashesInName) { TEST_P(CypherMainVisitorTest, CallProcedureWithDashesInName) {
AddProc(*mock_module, "proc-with-dashes", {}, {"res"}, ProcedureType::READ);
auto &ast_generator = *GetParam(); auto &ast_generator = *GetParam();
auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("CALL `proc-with-dashes`() YIELD res"));
auto *query =
dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("CALL `mock_module.proc-with-dashes`() YIELD res"));
ASSERT_TRUE(query); ASSERT_TRUE(query);
ASSERT_TRUE(query->single_query_); ASSERT_TRUE(query->single_query_);
auto *single_query = query->single_query_; auto *single_query = query->single_query_;
ASSERT_EQ(single_query->clauses_.size(), 1U); ASSERT_EQ(single_query->clauses_.size(), 1U);
auto *call_proc = dynamic_cast<CallProcedure *>(single_query->clauses_[0]); auto *call_proc = dynamic_cast<CallProcedure *>(single_query->clauses_[0]);
ASSERT_TRUE(call_proc); ASSERT_TRUE(call_proc);
ASSERT_EQ(call_proc->procedure_name_, "proc-with-dashes"); ASSERT_EQ(call_proc->procedure_name_, "mock_module.proc-with-dashes");
ASSERT_TRUE(call_proc->arguments_.empty()); ASSERT_TRUE(call_proc->arguments_.empty());
std::vector<std::string> identifier_names; std::vector<std::string> identifier_names;
identifier_names.reserve(call_proc->result_identifiers_.size()); identifier_names.reserve(call_proc->result_identifiers_.size());
@ -2603,34 +2682,45 @@ TEST_P(CypherMainVisitorTest, CallProcedureWithDashesInName) {
TEST_P(CypherMainVisitorTest, CallProcedureWithYieldSomeFields) { TEST_P(CypherMainVisitorTest, CallProcedureWithYieldSomeFields) {
auto &ast_generator = *GetParam(); auto &ast_generator = *GetParam();
auto *query = auto check_proc = [this, &ast_generator](const ProcedureType type) {
dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("CALL proc() YIELD fst, `field-with-dashes`, last_field")); const auto proc_name = std::string{"proc_"} + ToString(type);
ASSERT_TRUE(query); SCOPED_TRACE(proc_name);
ASSERT_TRUE(query->single_query_); const auto fully_qualified_proc_name = std::string{"mock_module."} + proc_name;
auto *single_query = query->single_query_; AddProc(*mock_module, proc_name.c_str(), {}, {"fst", "field-with-dashes", "last_field"}, type);
ASSERT_EQ(single_query->clauses_.size(), 1U); auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery(
auto *call_proc = dynamic_cast<CallProcedure *>(single_query->clauses_[0]); fmt::format("CALL {}() YIELD fst, `field-with-dashes`, last_field", fully_qualified_proc_name)));
ASSERT_TRUE(call_proc); ASSERT_TRUE(query);
ASSERT_EQ(call_proc->procedure_name_, "proc"); ASSERT_TRUE(query->single_query_);
ASSERT_TRUE(call_proc->arguments_.empty()); auto *single_query = query->single_query_;
ASSERT_EQ(call_proc->result_fields_.size(), 3U); ASSERT_EQ(single_query->clauses_.size(), 1U);
ASSERT_EQ(call_proc->result_identifiers_.size(), call_proc->result_fields_.size()); auto *call_proc = dynamic_cast<CallProcedure *>(single_query->clauses_[0]);
std::vector<std::string> identifier_names; ASSERT_TRUE(call_proc);
identifier_names.reserve(call_proc->result_identifiers_.size()); ASSERT_EQ(call_proc->is_write_, type == ProcedureType::WRITE);
for (const auto *identifier : call_proc->result_identifiers_) { ASSERT_EQ(call_proc->procedure_name_, fully_qualified_proc_name);
ASSERT_TRUE(identifier->user_declared_); ASSERT_TRUE(call_proc->arguments_.empty());
identifier_names.push_back(identifier->name_); ASSERT_EQ(call_proc->result_fields_.size(), 3U);
} ASSERT_EQ(call_proc->result_identifiers_.size(), call_proc->result_fields_.size());
std::vector<std::string> expected_names{"fst", "field-with-dashes", "last_field"}; std::vector<std::string> identifier_names;
ASSERT_EQ(identifier_names, expected_names); identifier_names.reserve(call_proc->result_identifiers_.size());
ASSERT_EQ(identifier_names, call_proc->result_fields_); for (const auto *identifier : call_proc->result_identifiers_) {
CheckCallProcedureDefaultMemoryLimit(ast_generator, *call_proc); ASSERT_TRUE(identifier->user_declared_);
identifier_names.push_back(identifier->name_);
}
std::vector<std::string> expected_names{"fst", "field-with-dashes", "last_field"};
ASSERT_EQ(identifier_names, expected_names);
ASSERT_EQ(identifier_names, call_proc->result_fields_);
CheckCallProcedureDefaultMemoryLimit(ast_generator, *call_proc);
};
check_proc(ProcedureType::READ);
check_proc(ProcedureType::WRITE);
} }
TEST_P(CypherMainVisitorTest, CallProcedureWithYieldAliasedFields) { TEST_P(CypherMainVisitorTest, CallProcedureWithYieldAliasedFields) {
AddProc(*mock_module, "proc", {}, {"fst", "snd", "thrd"}, ProcedureType::READ);
auto &ast_generator = *GetParam(); auto &ast_generator = *GetParam();
auto *query = auto *query =
dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("CALL proc() YIELD fst AS res1, snd AS " dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("CALL mock_module.proc() YIELD fst AS res1, snd AS "
"`result-with-dashes`, thrd AS last_result")); "`result-with-dashes`, thrd AS last_result"));
ASSERT_TRUE(query); ASSERT_TRUE(query);
ASSERT_TRUE(query->single_query_); ASSERT_TRUE(query->single_query_);
@ -2638,7 +2728,7 @@ TEST_P(CypherMainVisitorTest, CallProcedureWithYieldAliasedFields) {
ASSERT_EQ(single_query->clauses_.size(), 1U); ASSERT_EQ(single_query->clauses_.size(), 1U);
auto *call_proc = dynamic_cast<CallProcedure *>(single_query->clauses_[0]); auto *call_proc = dynamic_cast<CallProcedure *>(single_query->clauses_[0]);
ASSERT_TRUE(call_proc); ASSERT_TRUE(call_proc);
ASSERT_EQ(call_proc->procedure_name_, "proc"); ASSERT_EQ(call_proc->procedure_name_, "mock_module.proc");
ASSERT_TRUE(call_proc->arguments_.empty()); ASSERT_TRUE(call_proc->arguments_.empty());
ASSERT_EQ(call_proc->result_fields_.size(), 3U); ASSERT_EQ(call_proc->result_fields_.size(), 3U);
ASSERT_EQ(call_proc->result_identifiers_.size(), call_proc->result_fields_.size()); ASSERT_EQ(call_proc->result_identifiers_.size(), call_proc->result_fields_.size());
@ -2656,15 +2746,16 @@ TEST_P(CypherMainVisitorTest, CallProcedureWithYieldAliasedFields) {
} }
TEST_P(CypherMainVisitorTest, CallProcedureWithArguments) { TEST_P(CypherMainVisitorTest, CallProcedureWithArguments) {
AddProc(*mock_module, "proc", {"arg1", "arg2", "arg3"}, {"res"}, ProcedureType::READ);
auto &ast_generator = *GetParam(); auto &ast_generator = *GetParam();
auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("CALL proc(0, 1, 2) YIELD res")); auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("CALL mock_module.proc(0, 1, 2) YIELD res"));
ASSERT_TRUE(query); ASSERT_TRUE(query);
ASSERT_TRUE(query->single_query_); ASSERT_TRUE(query->single_query_);
auto *single_query = query->single_query_; auto *single_query = query->single_query_;
ASSERT_EQ(single_query->clauses_.size(), 1U); ASSERT_EQ(single_query->clauses_.size(), 1U);
auto *call_proc = dynamic_cast<CallProcedure *>(single_query->clauses_[0]); auto *call_proc = dynamic_cast<CallProcedure *>(single_query->clauses_[0]);
ASSERT_TRUE(call_proc); ASSERT_TRUE(call_proc);
ASSERT_EQ(call_proc->procedure_name_, "proc"); ASSERT_EQ(call_proc->procedure_name_, "mock_module.proc");
ASSERT_EQ(call_proc->arguments_.size(), 3U); ASSERT_EQ(call_proc->arguments_.size(), 3U);
for (int64_t i = 0; i < 3; ++i) { for (int64_t i = 0; i < 3; ++i) {
ast_generator.CheckLiteral(call_proc->arguments_[i], i); ast_generator.CheckLiteral(call_proc->arguments_[i], i);
@ -2681,7 +2772,7 @@ TEST_P(CypherMainVisitorTest, CallProcedureWithArguments) {
CheckCallProcedureDefaultMemoryLimit(ast_generator, *call_proc); CheckCallProcedureDefaultMemoryLimit(ast_generator, *call_proc);
} }
TEST_P(CypherMainVisitorTest, CallYieldAsterisk) { TEST_P(CypherMainVisitorTest, CallProcedureYieldAsterisk) {
auto &ast_generator = *GetParam(); auto &ast_generator = *GetParam();
auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("CALL mg.procedures() YIELD *")); auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("CALL mg.procedures() YIELD *"));
ASSERT_TRUE(query); ASSERT_TRUE(query);
@ -2703,7 +2794,7 @@ TEST_P(CypherMainVisitorTest, CallYieldAsterisk) {
CheckCallProcedureDefaultMemoryLimit(ast_generator, *call_proc); CheckCallProcedureDefaultMemoryLimit(ast_generator, *call_proc);
} }
TEST_P(CypherMainVisitorTest, CallYieldAsteriskReturnAsterisk) { TEST_P(CypherMainVisitorTest, CallProcedureYieldAsteriskReturnAsterisk) {
auto &ast_generator = *GetParam(); auto &ast_generator = *GetParam();
auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("CALL mg.procedures() YIELD * RETURN *")); auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("CALL mg.procedures() YIELD * RETURN *"));
ASSERT_TRUE(query); ASSERT_TRUE(query);
@ -2728,7 +2819,7 @@ TEST_P(CypherMainVisitorTest, CallYieldAsteriskReturnAsterisk) {
CheckCallProcedureDefaultMemoryLimit(ast_generator, *call_proc); CheckCallProcedureDefaultMemoryLimit(ast_generator, *call_proc);
} }
TEST_P(CypherMainVisitorTest, CallWithoutYield) { TEST_P(CypherMainVisitorTest, CallProcedureWithoutYield) {
auto &ast_generator = *GetParam(); auto &ast_generator = *GetParam();
auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("CALL mg.load_all()")); auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("CALL mg.load_all()"));
ASSERT_TRUE(query); ASSERT_TRUE(query);
@ -2744,7 +2835,7 @@ TEST_P(CypherMainVisitorTest, CallWithoutYield) {
CheckCallProcedureDefaultMemoryLimit(ast_generator, *call_proc); CheckCallProcedureDefaultMemoryLimit(ast_generator, *call_proc);
} }
TEST_P(CypherMainVisitorTest, CallWithMemoryLimitWithoutYield) { TEST_P(CypherMainVisitorTest, CallProcedureWithMemoryLimitWithoutYield) {
auto &ast_generator = *GetParam(); auto &ast_generator = *GetParam();
auto *query = auto *query =
dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("CALL mg.load_all() PROCEDURE MEMORY LIMIT 32 KB")); dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("CALL mg.load_all() PROCEDURE MEMORY LIMIT 32 KB"));
@ -2762,7 +2853,7 @@ TEST_P(CypherMainVisitorTest, CallWithMemoryLimitWithoutYield) {
ASSERT_EQ(call_proc->memory_scale_, 1024); ASSERT_EQ(call_proc->memory_scale_, 1024);
} }
TEST_P(CypherMainVisitorTest, CallWithMemoryUnlimitedWithoutYield) { TEST_P(CypherMainVisitorTest, CallProcedureWithMemoryUnlimitedWithoutYield) {
auto &ast_generator = *GetParam(); auto &ast_generator = *GetParam();
auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("CALL mg.load_all() PROCEDURE MEMORY UNLIMITED")); auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("CALL mg.load_all() PROCEDURE MEMORY UNLIMITED"));
ASSERT_TRUE(query); ASSERT_TRUE(query);
@ -2781,14 +2872,14 @@ TEST_P(CypherMainVisitorTest, CallWithMemoryUnlimitedWithoutYield) {
TEST_P(CypherMainVisitorTest, CallProcedureWithMemoryLimit) { TEST_P(CypherMainVisitorTest, CallProcedureWithMemoryLimit) {
auto &ast_generator = *GetParam(); auto &ast_generator = *GetParam();
auto *query = dynamic_cast<CypherQuery *>( auto *query = dynamic_cast<CypherQuery *>(
ast_generator.ParseQuery("CALL proc.with.dots() PROCEDURE MEMORY LIMIT 32 MB YIELD res")); ast_generator.ParseQuery("CALL mg.load_all() PROCEDURE MEMORY LIMIT 32 MB YIELD res"));
ASSERT_TRUE(query); ASSERT_TRUE(query);
ASSERT_TRUE(query->single_query_); ASSERT_TRUE(query->single_query_);
auto *single_query = query->single_query_; auto *single_query = query->single_query_;
ASSERT_EQ(single_query->clauses_.size(), 1U); ASSERT_EQ(single_query->clauses_.size(), 1U);
auto *call_proc = dynamic_cast<CallProcedure *>(single_query->clauses_[0]); auto *call_proc = dynamic_cast<CallProcedure *>(single_query->clauses_[0]);
ASSERT_TRUE(call_proc); ASSERT_TRUE(call_proc);
ASSERT_EQ(call_proc->procedure_name_, "proc.with.dots"); ASSERT_EQ(call_proc->procedure_name_, "mg.load_all");
ASSERT_TRUE(call_proc->arguments_.empty()); ASSERT_TRUE(call_proc->arguments_.empty());
std::vector<std::string> identifier_names; std::vector<std::string> identifier_names;
identifier_names.reserve(call_proc->result_identifiers_.size()); identifier_names.reserve(call_proc->result_identifiers_.size());
@ -2805,15 +2896,15 @@ TEST_P(CypherMainVisitorTest, CallProcedureWithMemoryLimit) {
TEST_P(CypherMainVisitorTest, CallProcedureWithMemoryUnlimited) { TEST_P(CypherMainVisitorTest, CallProcedureWithMemoryUnlimited) {
auto &ast_generator = *GetParam(); auto &ast_generator = *GetParam();
auto *query = dynamic_cast<CypherQuery *>( auto *query =
ast_generator.ParseQuery("CALL proc.with.dots() PROCEDURE MEMORY UNLIMITED YIELD res")); dynamic_cast<CypherQuery *>(ast_generator.ParseQuery("CALL mg.load_all() PROCEDURE MEMORY UNLIMITED YIELD res"));
ASSERT_TRUE(query); ASSERT_TRUE(query);
ASSERT_TRUE(query->single_query_); ASSERT_TRUE(query->single_query_);
auto *single_query = query->single_query_; auto *single_query = query->single_query_;
ASSERT_EQ(single_query->clauses_.size(), 1U); ASSERT_EQ(single_query->clauses_.size(), 1U);
auto *call_proc = dynamic_cast<CallProcedure *>(single_query->clauses_[0]); auto *call_proc = dynamic_cast<CallProcedure *>(single_query->clauses_[0]);
ASSERT_TRUE(call_proc); ASSERT_TRUE(call_proc);
ASSERT_EQ(call_proc->procedure_name_, "proc.with.dots"); ASSERT_EQ(call_proc->procedure_name_, "mg.load_all");
ASSERT_TRUE(call_proc->arguments_.empty()); ASSERT_TRUE(call_proc->arguments_.empty());
std::vector<std::string> identifier_names; std::vector<std::string> identifier_names;
identifier_names.reserve(call_proc->result_identifiers_.size()); identifier_names.reserve(call_proc->result_identifiers_.size());
@ -2827,6 +2918,243 @@ TEST_P(CypherMainVisitorTest, CallProcedureWithMemoryUnlimited) {
ASSERT_FALSE(call_proc->memory_limit_); ASSERT_FALSE(call_proc->memory_limit_);
} }
namespace {
template <typename TException = SyntaxException>
void TestInvalidQuery(const auto &query, Base &ast_generator) {
EXPECT_THROW(ast_generator.ParseQuery(query), TException) << query;
}
template <typename TException = SyntaxException>
void TestInvalidQueryWithMessage(const auto &query, Base &ast_generator, const std::string_view message) {
bool exception_is_thrown = false;
try {
ast_generator.ParseQuery(query);
} catch (const TException &se) {
EXPECT_EQ(std::string_view{se.what()}, message);
exception_is_thrown = true;
} catch (...) {
FAIL() << "Unexpected exception";
}
EXPECT_TRUE(exception_is_thrown);
}
void CheckParsedCallProcedure(const CypherQuery &query, Base &ast_generator,
const std::string_view fully_qualified_proc_name,
const std::vector<std::string_view> &args, const ProcedureType type,
const size_t clauses_size, const size_t call_procedure_index) {
ASSERT_NE(query.single_query_, nullptr);
auto *single_query = query.single_query_;
EXPECT_EQ(single_query->clauses_.size(), clauses_size);
ASSERT_FALSE(single_query->clauses_.empty());
ASSERT_LT(call_procedure_index, clauses_size);
auto *call_proc = dynamic_cast<CallProcedure *>(single_query->clauses_[call_procedure_index]);
ASSERT_NE(call_proc, nullptr);
EXPECT_EQ(call_proc->procedure_name_, fully_qualified_proc_name);
EXPECT_TRUE(call_proc->arguments_.empty());
EXPECT_EQ(call_proc->result_fields_.size(), 2U);
EXPECT_EQ(call_proc->result_identifiers_.size(), call_proc->result_fields_.size());
std::vector<std::string> identifier_names;
identifier_names.reserve(call_proc->result_identifiers_.size());
for (const auto *identifier : call_proc->result_identifiers_) {
EXPECT_TRUE(identifier->user_declared_);
identifier_names.push_back(identifier->name_);
}
std::vector<std::string> args_as_str{};
std::transform(args.begin(), args.end(), std::back_inserter(args_as_str),
[](const std::string_view &arg) { return std::string{arg}; });
EXPECT_EQ(identifier_names, args_as_str);
EXPECT_EQ(identifier_names, call_proc->result_fields_);
ASSERT_EQ(call_proc->is_write_, type == ProcedureType::WRITE);
CheckCallProcedureDefaultMemoryLimit(ast_generator, *call_proc);
};
} // namespace
TEST_P(CypherMainVisitorTest, CallProcedureMultipleQueryPartsAfter) {
auto &ast_generator = *GetParam();
constexpr std::string_view fst{"fst"};
constexpr std::string_view snd{"snd"};
const std::vector args{fst, snd};
const auto read_proc = CreateProcByType(ProcedureType::READ, args);
const auto write_proc = CreateProcByType(ProcedureType::WRITE, args);
const auto check_parsed_call_proc = [&ast_generator, &args](const CypherQuery &query,
const std::string_view fully_qualified_proc_name,
const ProcedureType type, const size_t clause_size) {
CheckParsedCallProcedure(query, ast_generator, fully_qualified_proc_name, args, type, clause_size, 0);
};
{
SCOPED_TRACE("Read query part");
{
SCOPED_TRACE("With WITH");
constexpr std::string_view kQueryWithWith{"CALL {}() YIELD {},{} WITH {},{} UNWIND {} as u RETURN u"};
constexpr size_t kQueryParts{4};
{
SCOPED_TRACE("Write proc");
const auto query_str = fmt::format(kQueryWithWith, write_proc, fst, snd, fst, snd, fst);
TestInvalidQueryWithMessage<SemanticException>(
query_str, ast_generator,
"WITH can't be put after calling a writeable procedure, only RETURN clause can be put after.");
}
{
SCOPED_TRACE("Read proc");
const auto query_str = fmt::format(kQueryWithWith, read_proc, fst, snd, fst, snd, fst);
const auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery(query_str));
ASSERT_NE(query, nullptr);
check_parsed_call_proc(*query, read_proc, ProcedureType::READ, kQueryParts);
}
}
{
SCOPED_TRACE("Without WITH");
constexpr std::string_view kQueryWithoutWith{"CALL {}() YIELD {},{} UNWIND {} as u RETURN u"};
constexpr size_t kQueryParts{3};
{
SCOPED_TRACE("Write proc");
const auto query_str = fmt::format(kQueryWithoutWith, write_proc, fst, snd, fst);
TestInvalidQueryWithMessage<SemanticException>(
query_str, ast_generator,
"UNWIND can't be put after calling a writeable procedure, only RETURN clause can be put after.");
}
{
SCOPED_TRACE("Read proc");
const auto query_str = fmt::format(kQueryWithoutWith, read_proc, fst, snd, fst);
const auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery(query_str));
ASSERT_NE(query, nullptr);
check_parsed_call_proc(*query, read_proc, ProcedureType::READ, kQueryParts);
}
}
}
{
SCOPED_TRACE("Write query part");
{
SCOPED_TRACE("With WITH");
constexpr std::string_view kQueryWithWith{"CALL {}() YIELD {},{} WITH {},{} CREATE(n {{prop : {}}}) RETURN n"};
constexpr size_t kQueryParts{4};
{
SCOPED_TRACE("Write proc");
const auto query_str = fmt::format(kQueryWithWith, write_proc, fst, snd, fst, snd, fst);
TestInvalidQueryWithMessage<SemanticException>(
query_str, ast_generator,
"WITH can't be put after calling a writeable procedure, only RETURN clause can be put after.");
}
{
SCOPED_TRACE("Read proc");
const auto query_str = fmt::format(kQueryWithWith, read_proc, fst, snd, fst, snd, fst);
const auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery(query_str));
ASSERT_NE(query, nullptr);
check_parsed_call_proc(*query, read_proc, ProcedureType::READ, kQueryParts);
}
}
{
SCOPED_TRACE("Without WITH");
constexpr std::string_view kQueryWithoutWith{"CALL {}() YIELD {},{} CREATE(n {{prop : {}}}) RETURN n"};
constexpr size_t kQueryParts{3};
{
SCOPED_TRACE("Write proc");
const auto query_str = fmt::format(kQueryWithoutWith, write_proc, fst, snd, fst);
TestInvalidQueryWithMessage<SemanticException>(
query_str, ast_generator,
"Update clause can't be put after calling a writeable procedure, only RETURN clause can be put after.");
}
{
SCOPED_TRACE("Read proc");
const auto query_str = fmt::format(kQueryWithoutWith, read_proc, fst, snd, fst, snd, fst);
const auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery(query_str));
ASSERT_NE(query, nullptr);
check_parsed_call_proc(*query, read_proc, ProcedureType::READ, kQueryParts);
}
}
}
}
TEST_P(CypherMainVisitorTest, CallProcedureMultipleQueryPartsBefore) {
auto &ast_generator = *GetParam();
constexpr std::string_view fst{"fst"};
constexpr std::string_view snd{"snd"};
const std::vector args{fst, snd};
const auto read_proc = CreateProcByType(ProcedureType::READ, args);
const auto write_proc = CreateProcByType(ProcedureType::WRITE, args);
const auto check_parsed_call_proc = [&ast_generator, &args](const CypherQuery &query,
const std::string_view fully_qualified_proc_name,
const ProcedureType type, const size_t clause_size) {
CheckParsedCallProcedure(query, ast_generator, fully_qualified_proc_name, args, type, clause_size, clause_size - 2);
};
{
SCOPED_TRACE("Read query part");
constexpr std::string_view kQueryWithReadQueryPart{"MATCH (n) CALL {}() YIELD * RETURN *"};
constexpr size_t kQueryParts{3};
{
SCOPED_TRACE("Write proc");
const auto query_str = fmt::format(kQueryWithReadQueryPart, write_proc);
const auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery(query_str));
ASSERT_NE(query, nullptr);
check_parsed_call_proc(*query, write_proc, ProcedureType::WRITE, kQueryParts);
}
{
SCOPED_TRACE("Read proc");
const auto query_str = fmt::format(kQueryWithReadQueryPart, read_proc);
const auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery(query_str));
ASSERT_NE(query, nullptr);
check_parsed_call_proc(*query, read_proc, ProcedureType::READ, kQueryParts);
}
}
{
SCOPED_TRACE("Write query part");
constexpr std::string_view kQueryWithWriteQueryPart{"CREATE (n) WITH n CALL {}() YIELD * RETURN *"};
constexpr size_t kQueryParts{4};
{
SCOPED_TRACE("Write proc");
const auto query_str = fmt::format(kQueryWithWriteQueryPart, write_proc, fst, snd, fst);
TestInvalidQueryWithMessage<SemanticException>(
query_str, ast_generator, "Write procedures cannot be used in queries that contains any update clauses!");
}
{
SCOPED_TRACE("Read proc");
const auto query_str = fmt::format(kQueryWithWriteQueryPart, read_proc, fst, snd, fst, snd, fst);
const auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery(query_str));
ASSERT_NE(query, nullptr);
check_parsed_call_proc(*query, read_proc, ProcedureType::READ, kQueryParts);
}
}
}
TEST_P(CypherMainVisitorTest, CallProcedureMultipleProcedures) {
auto &ast_generator = *GetParam();
constexpr std::string_view fst{"fst"};
constexpr std::string_view snd{"snd"};
const std::vector args{fst, snd};
const auto read_proc = CreateProcByType(ProcedureType::READ, args);
const auto write_proc = CreateProcByType(ProcedureType::WRITE, args);
{
SCOPED_TRACE("Read then write");
const auto query_str = fmt::format("CALL {}() YIELD * CALL {}() YIELD * RETURN *", read_proc, write_proc);
constexpr size_t kQueryParts{3};
const auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery(query_str));
ASSERT_NE(query, nullptr);
CheckParsedCallProcedure(*query, ast_generator, read_proc, args, ProcedureType::READ, kQueryParts, 0);
CheckParsedCallProcedure(*query, ast_generator, write_proc, args, ProcedureType::WRITE, kQueryParts, 1);
}
{
SCOPED_TRACE("Write then read");
const auto query_str = fmt::format("CALL {}() YIELD * CALL {}() YIELD * RETURN *", write_proc, read_proc);
TestInvalidQueryWithMessage<SemanticException>(
query_str, ast_generator,
"CALL can't be put after calling a writeable procedure, only RETURN clause can be put after.");
}
{
SCOPED_TRACE("Write twice");
const auto query_str = fmt::format("CALL {}() YIELD * CALL {}() YIELD * RETURN *", write_proc, write_proc);
TestInvalidQueryWithMessage<SemanticException>(
query_str, ast_generator,
"CALL can't be put after calling a writeable procedure, only RETURN clause can be put after.");
}
}
TEST_P(CypherMainVisitorTest, IncorrectCallProcedure) { TEST_P(CypherMainVisitorTest, IncorrectCallProcedure) {
auto &ast_generator = *GetParam(); auto &ast_generator = *GetParam();
ASSERT_THROW(ast_generator.ParseQuery("CALL proc-with-dashes()"), SyntaxException); ASSERT_THROW(ast_generator.ParseQuery("CALL proc-with-dashes()"), SyntaxException);
@ -3062,13 +3390,6 @@ TEST_P(CypherMainVisitorTest, MemoryLimit) {
} }
} }
namespace {
template <typename TException = SyntaxException>
void TestInvalidQuery(const auto &query, Base &ast_generator) {
EXPECT_THROW(ast_generator.ParseQuery(query), TException) << query;
}
} // namespace
TEST_P(CypherMainVisitorTest, DropTrigger) { TEST_P(CypherMainVisitorTest, DropTrigger) {
auto &ast_generator = *GetParam(); auto &ast_generator = *GetParam();
@ -3460,4 +3781,3 @@ TEST_P(CypherMainVisitorTest, SettingQuery) {
validate_setting_query("SET DATABASE SETTING 'setting' TO 'value'", SettingQuery::Action::SET_SETTING, validate_setting_query("SET DATABASE SETTING 'setting' TO 'value'", SettingQuery::Action::SET_SETTING,
TypedValue{"setting"}, TypedValue{"value"}); TypedValue{"setting"}, TypedValue{"value"});
} }
} // namespace