Add syntax for limiting memory of CALL

Reviewers: mferencevic, ipaljak

Reviewed By: ipaljak

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D2602
This commit is contained in:
Teon Banek 2019-12-11 15:04:32 +01:00
parent f27682dc26
commit 0c111d52dc
9 changed files with 205 additions and 26 deletions

View File

@ -1596,7 +1596,11 @@ cpp<#
(result-identifiers "std::vector<Identifier *>"
:scope :public
:slk-save #'slk-save-ast-vector
:slk-load (slk-load-ast-vector "Identifier")))
:slk-load (slk-load-ast-vector "Identifier"))
(memory-limit "Expression *" :initval "nullptr" :scope :public
:slk-save #'slk-save-ast-pointer
:slk-load (slk-load-ast-pointer "Expression"))
(memory-scale "size_t" :initval "1024U" :scope :public))
(:public
#>cpp
CallProcedure() = default;

View File

@ -368,6 +368,22 @@ antlrcpp::Any CypherMainVisitor::visitCallProcedure(
for (auto *expr : ctx->expression()) {
call_proc->arguments_.push_back(expr->accept(this));
}
if (auto *memory_limit_ctx = ctx->callProcedureMemoryLimit()) {
if (memory_limit_ctx->LIMIT()) {
call_proc->memory_limit_ = memory_limit_ctx->literal()->accept(this);
if (memory_limit_ctx->MB()) {
call_proc->memory_scale_ = 1024U * 1024U;
} else {
CHECK(memory_limit_ctx->KB());
call_proc->memory_scale_ = 1024U;
}
}
} else {
// Default to 100 MB
call_proc->memory_limit_ =
storage_->Create<PrimitiveLiteral>(TypedValue(100));
call_proc->memory_scale_ = 1024U * 1024U;
}
auto *yield_ctx = ctx->yieldProcedureResults();
if (!yield_ctx) {
const auto &maybe_found = procedure::FindProcedure(

View File

@ -108,10 +108,12 @@ with : WITH ( DISTINCT )? returnBody ( where )? ;
cypherReturn : RETURN ( DISTINCT )? returnBody ;
callProcedure : CALL procedureName '(' ( expression ( ',' expression )* )? ')' ( yieldProcedureResults )? ;
callProcedure : CALL procedureName '(' ( expression ( ',' expression )* )? ')' ( callProcedureMemoryLimit )? ( yieldProcedureResults )? ;
procedureName : symbolicName ( '.' symbolicName )* ;
// Optional memory clause of a CALL: either `MEMORY UNLIMITED`, or
// `MEMORY LIMIT` followed by a literal and a unit token (MB or KB).
callProcedureMemoryLimit : MEMORY ( UNLIMITED | LIMIT literal ( MB | KB ) ) ;
yieldProcedureResults : YIELD ( '*' | ( procedureResult ( ',' procedureResult )* ) ) ;
procedureResult : ( variable AS variable ) | variable ;

View File

@ -103,10 +103,13 @@ IN : I N ;
INDEX : I N D E X ;
INFO : I N F O ;
IS : I S ;
KB : K B ;
KEY : K E Y ;
LIMIT : L I M I T ;
L_SKIP : S K I P ;
MATCH : M A T C H ;
MB : M B ;
MEMORY : M E M O R Y ;
MERGE : M E R G E ;
NODE : N O D E ;
NONE : N O N E ;
@ -129,6 +132,7 @@ THEN : T H E N ;
TRUE : T R U E ;
UNION : U N I O N ;
UNIQUE : U N I Q U E ;
UNLIMITED : U N L I M I T E D ;
UNWIND : U N W I N D ;
WHEN : W H E N ;
WHERE : W H E R E ;

View File

@ -88,10 +88,10 @@ const trie::Trie kKeywords = {
"when", "then", "else", "end", "count", "filter",
"extract", "any", "none", "single", "true", "false",
"reduce", "coalesce", "user", "password", "alter", "drop",
"show", "stats",
"unique", "explain", "profile",
"storage", "index", "info", "exists", "assert", "constraint",
"node", "key", "dump", "database", "call", "yield"};
"show", "stats", "unique", "explain", "profile", "storage",
"index", "info", "exists", "assert", "constraint", "node",
"key", "dump", "database", "call", "yield", "memory",
"mb", "kb", "unlimited"};
// Unicode codepoints that are allowed at the start of the unescaped name.
const std::bitset<kBitsetSize> kUnescapedNameAllowedStarts(std::string(

View File

@ -3714,12 +3714,15 @@ UniqueCursorPtr OutputTableStream::MakeCursor(
CallProcedure::CallProcedure(std::shared_ptr<LogicalOperator> input,
std::string name, std::vector<Expression *> args,
std::vector<std::string> fields,
std::vector<Symbol> symbols)
std::vector<Symbol> symbols,
Expression *memory_limit, size_t memory_scale)
: input_(input ? input : std::make_shared<Once>()),
procedure_name_(name),
arguments_(args),
result_fields_(fields),
result_symbols_(symbols) {}
result_symbols_(symbols),
memory_limit_(memory_limit),
memory_scale_(memory_scale) {}
ACCEPT_WITH_INPUT(CallProcedure);
@ -3736,11 +3739,26 @@ std::vector<Symbol> CallProcedure::ModifiedSymbols(
namespace {
/// Evaluates the optional memory limit expression of a procedure call and
/// converts it to a number of bytes.
///
/// @param eval evaluator used to compute the value of `memory_limit`.
/// @param memory_limit expression producing the limit; `nullptr` means that
///        no limit should be applied.
/// @param memory_scale factor the evaluated limit is multiplied with.
/// @return `std::nullopt` when `memory_limit` is `nullptr`, otherwise the
///         evaluated limit multiplied by `memory_scale`.
/// @throw QueryRuntimeException if the evaluated value is not an integer
///        greater than zero, or if the scaled limit does not fit in `size_t`.
std::optional<size_t> EvalMemoryLimit(ExpressionEvaluator *eval,
                                      Expression *memory_limit,
                                      size_t memory_scale) {
  if (!memory_limit) return std::nullopt;
  auto limit_value = memory_limit->Accept(*eval);
  // Zero is rejected together with negative values, so the requirement
  // reported to the user is "positive", not "non-negative".
  if (!limit_value.IsInt() || limit_value.ValueInt() <= 0)
    throw QueryRuntimeException("Memory limit must be a positive integer.");
  // The value is known to be greater than zero here, so converting it to an
  // unsigned type cannot turn a negative number into a huge limit.
  const auto limit = static_cast<size_t>(limit_value.ValueInt());
  // Detect overflow before multiplying, because unsigned multiplication would
  // silently wrap around.
  if (std::numeric_limits<size_t>::max() / memory_scale < limit)
    throw QueryRuntimeException("Memory limit overflow.");
  return limit * memory_scale;
}
void CallCustomProcedure(const std::string_view &fully_qualified_procedure_name,
const mgp_proc &proc,
const std::vector<Expression *> &args,
const mgp_graph &graph, ExpressionEvaluator *evaluator,
utils::MemoryResource *memory, mgp_result *result) {
utils::MemoryResource *memory,
std::optional<size_t> memory_limit,
mgp_result *result) {
static_assert(std::uses_allocator_v<mgp_value, utils::Allocator<mgp_value>>,
"Expected mgp_value to use custom allocator and makes STL "
"containers aware of that");
@ -3791,17 +3809,26 @@ void CallCustomProcedure(const std::string_view &fully_qualified_procedure_name,
for (size_t i = passed_in_opt_args; i < proc.opt_args.size(); ++i) {
proc_args.elems.emplace_back(std::get<2>(proc.opt_args[i]), &graph);
}
// TODO: Add syntax for controlling procedure memory limits.
utils::LimitedMemoryResource limited_mem(memory,
100 * 1024 * 1024 /* 100 MB */);
mgp_memory proc_memory{&limited_mem};
CHECK(result->signature == &proc.results);
// TODO: What about cross library boundary exceptions? OMG C++?!
proc.cb(&proc_args, &graph, result, &proc_memory);
size_t leaked_bytes = limited_mem.GetAllocatedBytes();
LOG_IF(WARNING, leaked_bytes > 0U)
<< "Query procedure '" << fully_qualified_procedure_name << "' leaked "
<< leaked_bytes << " *tracked* bytes";
if (memory_limit) {
LOG(INFO) << "Running '" << fully_qualified_procedure_name
<< "' with memory limit of " << *memory_limit << " bytes";
utils::LimitedMemoryResource limited_mem(memory, *memory_limit);
mgp_memory proc_memory{&limited_mem};
CHECK(result->signature == &proc.results);
// TODO: What about cross library boundary exceptions? OMG C++?!
proc.cb(&proc_args, &graph, result, &proc_memory);
size_t leaked_bytes = limited_mem.GetAllocatedBytes();
LOG_IF(WARNING, leaked_bytes > 0U)
<< "Query procedure '" << fully_qualified_procedure_name << "' leaked "
<< leaked_bytes << " *tracked* bytes";
} else {
// TODO: Add a tracking MemoryResource without limits, so that we report
// memory leaks in procedure.
mgp_memory proc_memory{memory};
CHECK(result->signature == &proc.results);
// TODO: What about cross library boundary exceptions? OMG C++?!
proc.cb(&proc_args, &graph, result, &proc_memory);
}
}
} // namespace
@ -3867,9 +3894,11 @@ class CallProcedureCursor : public Cursor {
// TODO: This will probably need to be changed when we add support for
// generator like procedures which yield a new result on each invocation.
auto *memory = context.evaluation_context.memory;
auto memory_limit = EvalMemoryLimit(&evaluator, self_->memory_limit_,
self_->memory_scale_);
mgp_graph graph{context.db_accessor, graph_view};
CallCustomProcedure(self_->procedure_name_, *proc, self_->arguments_,
graph, &evaluator, memory, &result_);
graph, &evaluator, memory, memory_limit, &result_);
// Reset result_.signature to nullptr, because outside of this scope we
// will no longer hold a lock on the `module`. If someone were to reload
// it, the pointer would be invalid.

View File

@ -2069,13 +2069,18 @@ at once. Instead, each call of the callback should return a single row of the ta
:slk-save #'slk-save-ast-vector
:slk-load (slk-load-ast-vector "Expression"))
(result-fields "std::vector<std::string>" :scope :public)
(result-symbols "std::vector<Symbol>" :scope :public))
(result-symbols "std::vector<Symbol>" :scope :public)
(memory-limit "Expression *" :initval "nullptr" :scope :public
:slk-save #'slk-save-ast-pointer
:slk-load (slk-load-ast-pointer "Expression"))
(memory-scale "size_t" :initval "1024U" :scope :public))
(:public
#>cpp
CallProcedure() = default;
CallProcedure(std::shared_ptr<LogicalOperator> input, std::string name,
std::vector<Expression *> arguments,
std::vector<std::string> fields, std::vector<Symbol> symbols);
std::vector<std::string> fields, std::vector<Symbol> symbols,
Expression *memory_limit, size_t memory_scale);
bool Accept(HierarchicalLogicalOperatorVisitor &visitor) override;
UniqueCursorPtr MakeCursor(utils::MemoryResource *) const override;

View File

@ -223,7 +223,8 @@ class RuleBasedPlanner {
// storage::View::NEW.
input_op = std::make_unique<plan::CallProcedure>(
std::move(input_op), call_proc->procedure_name_,
call_proc->arguments_, call_proc->result_fields_, result_symbols);
call_proc->arguments_, call_proc->result_fields_, result_symbols,
call_proc->memory_limit_, call_proc->memory_scale_);
} else {
throw utils::NotYetImplemented(
"clause '{}' conversion to operator(s)",

View File

@ -52,8 +52,9 @@ class Base {
}
template <class TValue>
void CheckLiteral(Expression *expression, const TValue &expected,
const std::optional<int> &token_position = std::nullopt) {
void CheckLiteral(
Expression *expression, const TValue &expected,
const std::optional<int> &token_position = std::nullopt) const {
TypedValue value;
// NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
TypedValue expected_tv(expected);
@ -2669,6 +2670,19 @@ TEST_P(CypherMainVisitorTest, DumpDatabase) {
ASSERT_TRUE(query);
}
namespace {
/// Asserts that `call_proc` carries the default memory limit: a primitive
/// literal equal to 100 with a scale of 1024 * 1024 (i.e. 100 MB).
///
/// @param ast AST generator the query was parsed with; not used by the
///        checks below.
/// @param call_proc parsed CALL clause to inspect.
template <class TAst>
void CheckCallProcedureDefaultMemoryLimit(const TAst &ast,
                                          const CallProcedure &call_proc) {
  // Should be 100 MB
  auto *literal = dynamic_cast<PrimitiveLiteral *>(call_proc.memory_limit_);
  ASSERT_TRUE(literal);
  TypedValue value(literal->value_);
  ASSERT_TRUE(TypedValue::BoolEqual{}(value, TypedValue(100)));
  // `memory_scale_` is a size_t, so compare it against an unsigned constant
  // to avoid a signed/unsigned comparison.
  ASSERT_EQ(call_proc.memory_scale_, 1024U * 1024U);
}
} // namespace
TEST_P(CypherMainVisitorTest, CallProcedureWithDotsInName) {
auto &ast_generator = *GetParam();
auto *query = dynamic_cast<CypherQuery *>(
@ -2690,6 +2704,7 @@ TEST_P(CypherMainVisitorTest, CallProcedureWithDotsInName) {
std::vector<std::string> expected_names{"res"};
ASSERT_EQ(identifier_names, expected_names);
ASSERT_EQ(identifier_names, call_proc->result_fields_);
CheckCallProcedureDefaultMemoryLimit(ast_generator, *call_proc);
}
TEST_P(CypherMainVisitorTest, CallProcedureWithDashesInName) {
@ -2713,6 +2728,7 @@ TEST_P(CypherMainVisitorTest, CallProcedureWithDashesInName) {
std::vector<std::string> expected_names{"res"};
ASSERT_EQ(identifier_names, expected_names);
ASSERT_EQ(identifier_names, call_proc->result_fields_);
CheckCallProcedureDefaultMemoryLimit(ast_generator, *call_proc);
}
TEST_P(CypherMainVisitorTest, CallProcedureWithYieldSomeFields) {
@ -2740,6 +2756,7 @@ TEST_P(CypherMainVisitorTest, CallProcedureWithYieldSomeFields) {
"last_field"};
ASSERT_EQ(identifier_names, expected_names);
ASSERT_EQ(identifier_names, call_proc->result_fields_);
CheckCallProcedureDefaultMemoryLimit(ast_generator, *call_proc);
}
TEST_P(CypherMainVisitorTest, CallProcedureWithYieldAliasedFields) {
@ -2769,6 +2786,7 @@ TEST_P(CypherMainVisitorTest, CallProcedureWithYieldAliasedFields) {
ASSERT_EQ(identifier_names, aliased_names);
std::vector<std::string> field_names{"fst", "snd", "thrd"};
ASSERT_EQ(call_proc->result_fields_, field_names);
CheckCallProcedureDefaultMemoryLimit(ast_generator, *call_proc);
}
TEST_P(CypherMainVisitorTest, CallProcedureWithArguments) {
@ -2795,6 +2813,7 @@ TEST_P(CypherMainVisitorTest, CallProcedureWithArguments) {
std::vector<std::string> expected_names{"res"};
ASSERT_EQ(identifier_names, expected_names);
ASSERT_EQ(identifier_names, call_proc->result_fields_);
CheckCallProcedureDefaultMemoryLimit(ast_generator, *call_proc);
}
TEST_P(CypherMainVisitorTest, CallYieldAsterisk) {
@ -2818,6 +2837,7 @@ TEST_P(CypherMainVisitorTest, CallYieldAsterisk) {
std::vector<std::string> expected_names{"name", "signature"};
ASSERT_EQ(identifier_names, expected_names);
ASSERT_EQ(identifier_names, call_proc->result_fields_);
CheckCallProcedureDefaultMemoryLimit(ast_generator, *call_proc);
}
TEST_P(CypherMainVisitorTest, CallYieldAsteriskReturnAsterisk) {
@ -2844,6 +2864,7 @@ TEST_P(CypherMainVisitorTest, CallYieldAsteriskReturnAsterisk) {
std::vector<std::string> expected_names{"name", "signature"};
ASSERT_EQ(identifier_names, expected_names);
ASSERT_EQ(identifier_names, call_proc->result_fields_);
CheckCallProcedureDefaultMemoryLimit(ast_generator, *call_proc);
}
TEST_P(CypherMainVisitorTest, CallWithoutYield) {
@ -2860,6 +2881,91 @@ TEST_P(CypherMainVisitorTest, CallWithoutYield) {
ASSERT_TRUE(call_proc->arguments_.empty());
ASSERT_TRUE(call_proc->result_fields_.empty());
ASSERT_TRUE(call_proc->result_identifiers_.empty());
CheckCallProcedureDefaultMemoryLimit(ast_generator, *call_proc);
}
// A CALL without YIELD that specifies `MEMORY LIMIT 32 KB` must produce a
// CallProcedure clause whose limit is the literal 32 with a scale of 1024.
TEST_P(CypherMainVisitorTest, CallWithMemoryLimitWithoutYield) {
  auto &ast_generator = *GetParam();
  auto *query = dynamic_cast<CypherQuery *>(
      ast_generator.ParseQuery("CALL mg.reload_all() MEMORY LIMIT 32 KB"));
  ASSERT_TRUE(query);
  ASSERT_TRUE(query->single_query_);
  auto *single_query = query->single_query_;
  ASSERT_EQ(single_query->clauses_.size(), 1U);
  auto *call_proc = dynamic_cast<CallProcedure *>(single_query->clauses_[0]);
  ASSERT_TRUE(call_proc);
  ASSERT_EQ(call_proc->procedure_name_, "mg.reload_all");
  // No arguments were passed and nothing is yielded.
  ASSERT_TRUE(call_proc->arguments_.empty());
  ASSERT_TRUE(call_proc->result_fields_.empty());
  ASSERT_TRUE(call_proc->result_identifiers_.empty());
  ast_generator.CheckLiteral(call_proc->memory_limit_, 32);
  // `memory_scale_` is a size_t; use an unsigned constant like the other size
  // comparisons in this test.
  ASSERT_EQ(call_proc->memory_scale_, 1024U);
}
// A CALL without YIELD that specifies `MEMORY UNLIMITED` must produce a
// CallProcedure clause that has no memory limit expression.
TEST_P(CypherMainVisitorTest, CallWithMemoryUnlimitedWithoutYield) {
  auto &generator = *GetParam();
  auto *parsed = dynamic_cast<CypherQuery *>(
      generator.ParseQuery("CALL mg.reload_all() MEMORY UNLIMITED"));
  ASSERT_TRUE(parsed);
  ASSERT_TRUE(parsed->single_query_);
  const auto &clauses = parsed->single_query_->clauses_;
  ASSERT_EQ(clauses.size(), 1U);
  auto *call = dynamic_cast<CallProcedure *>(clauses[0]);
  ASSERT_TRUE(call);
  ASSERT_EQ(call->procedure_name_, "mg.reload_all");
  // Neither arguments nor yielded results are expected.
  ASSERT_TRUE(call->arguments_.empty());
  ASSERT_TRUE(call->result_fields_.empty());
  ASSERT_TRUE(call->result_identifiers_.empty());
  // UNLIMITED is represented by the absence of a limit expression.
  ASSERT_FALSE(call->memory_limit_);
}
// A CALL with `MEMORY LIMIT 32 MB` placed before YIELD must keep the yielded
// field and carry the literal 32 with a scale of 1024 * 1024.
TEST_P(CypherMainVisitorTest, CallProcedureWithMemoryLimit) {
  auto &ast_generator = *GetParam();
  auto *query = dynamic_cast<CypherQuery *>(ast_generator.ParseQuery(
      "CALL proc.with.dots() MEMORY LIMIT 32 MB YIELD res"));
  ASSERT_TRUE(query);
  ASSERT_TRUE(query->single_query_);
  auto *single_query = query->single_query_;
  ASSERT_EQ(single_query->clauses_.size(), 1U);
  auto *call_proc = dynamic_cast<CallProcedure *>(single_query->clauses_[0]);
  ASSERT_TRUE(call_proc);
  ASSERT_EQ(call_proc->procedure_name_, "proc.with.dots");
  ASSERT_TRUE(call_proc->arguments_.empty());
  // Collect the names of the yielded identifiers, checking that each one is
  // marked as user declared.
  std::vector<std::string> identifier_names;
  identifier_names.reserve(call_proc->result_identifiers_.size());
  for (const auto *identifier : call_proc->result_identifiers_) {
    ASSERT_TRUE(identifier->user_declared_);
    identifier_names.push_back(identifier->name_);
  }
  std::vector<std::string> expected_names{"res"};
  ASSERT_EQ(identifier_names, expected_names);
  ASSERT_EQ(identifier_names, call_proc->result_fields_);
  ast_generator.CheckLiteral(call_proc->memory_limit_, 32);
  // `memory_scale_` is a size_t; use an unsigned constant like the other size
  // comparisons in this test.
  ASSERT_EQ(call_proc->memory_scale_, 1024U * 1024U);
}
// A CALL with `MEMORY UNLIMITED` placed before YIELD must keep the yielded
// field and have no memory limit expression.
TEST_P(CypherMainVisitorTest, CallProcedureWithMemoryUnlimited) {
  auto &generator = *GetParam();
  auto *parsed = dynamic_cast<CypherQuery *>(generator.ParseQuery(
      "CALL proc.with.dots() MEMORY UNLIMITED YIELD res"));
  ASSERT_TRUE(parsed);
  ASSERT_TRUE(parsed->single_query_);
  const auto &clauses = parsed->single_query_->clauses_;
  ASSERT_EQ(clauses.size(), 1U);
  auto *call = dynamic_cast<CallProcedure *>(clauses[0]);
  ASSERT_TRUE(call);
  ASSERT_EQ(call->procedure_name_, "proc.with.dots");
  ASSERT_TRUE(call->arguments_.empty());
  // Gather the yielded identifier names; every identifier has to be marked as
  // user declared.
  std::vector<std::string> yielded_names;
  yielded_names.reserve(call->result_identifiers_.size());
  for (const auto *ident : call->result_identifiers_) {
    ASSERT_TRUE(ident->user_declared_);
    yielded_names.push_back(ident->name_);
  }
  std::vector<std::string> wanted_names{"res"};
  ASSERT_EQ(yielded_names, wanted_names);
  ASSERT_EQ(yielded_names, call->result_fields_);
  // UNLIMITED is represented by the absence of a limit expression.
  ASSERT_FALSE(call->memory_limit_);
}
TEST_P(CypherMainVisitorTest, IncorrectCallProcedure) {
@ -2888,9 +2994,21 @@ TEST_P(CypherMainVisitorTest, IncorrectCallProcedure) {
ASSERT_THROW(
ast_generator.ParseQuery("RETURN 42 AS x CALL procedure() YIELD res"),
SemanticException);
ASSERT_THROW(ast_generator.ParseQuery(
"CALL proc.with.dots() YIELD res MEMORY UNLIMITED"),
SyntaxException);
ASSERT_THROW(ast_generator.ParseQuery(
"CALL proc.with.dots() YIELD res MEMORY LIMIT 32 KB"),
SyntaxException);
ASSERT_THROW(
ast_generator.ParseQuery("CALL proc.with.dots() MEMORY YIELD res"),
SyntaxException);
// mg.procedures returns something, so it needs to have a YIELD.
ASSERT_THROW(ast_generator.ParseQuery("CALL mg.procedures()"),
SemanticException);
ASSERT_THROW(
ast_generator.ParseQuery("CALL mg.procedures() MEMORY UNLIMITED"),
SemanticException);
// TODO: Implement support for the following syntax. These are defined in
// Neo4j and accepted in openCypher CIP.
ASSERT_THROW(ast_generator.ParseQuery("CALL proc"), SyntaxException);