Add init-file and init-data-file capabilities (#696)

This commit is contained in:
Marko Budiselić 2022-12-09 18:50:33 +01:00 committed by GitHub
parent f2d5ab61c4
commit 9d6a23b6bd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 721 additions and 2 deletions

View File

@ -1,4 +1,4 @@
// Copyright 2021 Memgraph Ltd.
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
@ -12,6 +12,7 @@
#pragma once
#include <filesystem>
#include <fstream>
#include <string>
#include <vector>
@ -61,3 +62,42 @@ inline void LoadConfig(const std::string &product_name) {
for (int i = 0; i < custom_argc; ++i) free(custom_argv[i]);
delete[] custom_argv;
}
std::pair<std::string, std::string> LoadUsernameAndPassword(const std::string &pass_file) {
std::ifstream file(pass_file);
if (file.fail()) {
spdlog::warn("Problem with opening MG_PASSFILE, memgraph server will start without user");
return {};
}
std::vector<std::string> result;
std::string line;
std::getline(file, line);
size_t pos = 0;
std::string token;
static constexpr std::string_view delimiter{":"};
while ((pos = line.find(delimiter)) != std::string::npos) {
if (line[pos - 1] == '\\') {
line.erase(pos - 1, 1);
token += line.substr(0, pos);
line.erase(0, pos);
} else {
token += line.substr(0, pos);
result.push_back(token);
line.erase(0, pos + delimiter.length());
token = "";
}
}
result.push_back(line);
file.close();
if (result.size() != 2) {
spdlog::warn(
"Wrong data format. Data should be store in format: username:password, memgraph server will start without "
"user");
return {};
}
return {result[0], result[1]};
}

View File

@ -26,6 +26,7 @@
#include <string_view>
#include <thread>
#include <fmt/core.h>
#include <fmt/format.h>
#include <gflags/gflags.h>
#include <spdlog/common.h>
@ -98,6 +99,10 @@
#include "audit/log.hpp"
#endif
constexpr const char *kMgUser = "MEMGRAPH_USER";
constexpr const char *kMgPassword = "MEMGRAPH_PASSWORD";
constexpr const char *kMgPassfile = "MEMGRAPH_PASSFILE";
namespace {
std::string GetAllowedEnumValuesString(const auto &mappings) {
std::vector<std::string> allowed_values;
@ -167,6 +172,11 @@ DEFINE_string(bolt_key_file, "", "Key file which should be used for the Bolt ser
DEFINE_string(bolt_server_name_for_init, "",
"Server name which the database should send to the client in the "
"Bolt INIT message.");
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_string(init_file, "",
"Path to cypherl file that is used for configuring users and database schema before server starts.");
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_string(init_data_file, "", "Path to cypherl file that is used for creating data after server starts.");
// General purpose flags.
// NOTE: The `data_directory` flag must be the same here and in
@ -476,6 +486,33 @@ struct SessionData {
DEFINE_string(auth_user_or_role_name_regex, memgraph::glue::kDefaultUserRoleRegex.data(),
"Set to the regular expression that each user or role name must fulfill.");
void InitFromCypherlFile(memgraph::query::InterpreterContext &ctx, std::string cypherl_file_path
#ifdef MG_ENTERPRISE
,
memgraph::audit::Log *audit_log
#endif
) {
memgraph::query::Interpreter interpreter(&ctx);
std::ifstream file(cypherl_file_path);
if (file.is_open()) {
std::string line;
while (std::getline(file, line)) {
if (!line.empty()) {
auto results = interpreter.Prepare(line, {}, {});
memgraph::query::DiscardValueResultStream stream;
interpreter.Pull(&stream, {}, results.qid);
#ifdef MG_ENTERPRISE
if (memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
audit_log->Record("", "", line, {});
}
#endif
}
}
file.close();
}
}
class BoltSession final : public memgraph::communication::bolt::Session<memgraph::communication::v2::InputStream,
memgraph::communication::v2::OutputStream> {
public:
@ -889,6 +926,29 @@ int main(int argc, char **argv) {
interpreter_context.auth = &auth_handler;
interpreter_context.auth_checker = &auth_checker;
if (!FLAGS_init_file.empty()) {
spdlog::info("Running init file.");
#ifdef MG_ENTERPRISE
if (memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
InitFromCypherlFile(interpreter_context, FLAGS_init_file, &audit_log);
}
#else
InitFromCypherlFile(interpreter_context, FLAGS_init_file);
#endif
}
auto *maybe_username = std::getenv(kMgUser);
auto *maybe_password = std::getenv(kMgPassword);
auto *maybe_pass_file = std::getenv(kMgPassfile);
if (maybe_username && maybe_password) {
auth_handler.CreateUser(maybe_username, maybe_password);
} else if (maybe_pass_file) {
const auto [username, password] = LoadUsernameAndPassword(maybe_pass_file);
if (!username.empty() && !password.empty()) {
auth_handler.CreateUser(username, password);
}
}
{
// Triggers can execute query procedures, so we need to reload the modules first and then
// the triggers
@ -966,6 +1026,17 @@ int main(int argc, char **argv) {
MG_ASSERT(server.Start(), "Couldn't start the Bolt server!");
websocket_server.Start();
if (!FLAGS_init_data_file.empty()) {
spdlog::info("Running init data file.");
#ifdef MG_ENTERPRISE
if (memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
InitFromCypherlFile(interpreter_context, FLAGS_init_data_file, &audit_log);
}
#else
InitFromCypherlFile(interpreter_context, FLAGS_init_data_file);
#endif
}
server.AwaitShutdown();
websocket_server.AwaitShutdown();

View File

@ -163,6 +163,12 @@ startup_config_dict = {
),
"query_max_plans": ("1000", "1000", "Maximum number of generated plans for a query."),
"flag_file": ("", "", "load flags from file"),
"init_file": (
"",
"",
"Path to cypherl file that is used for configuring users and database schema before server starts.",
),
"init_data_file": ("", "", "Path to cypherl file that is used for creating data after server starts."),
"python_submodules_directory": (
"mage",
"mage",

View File

@ -1,7 +1,14 @@
template_cluster: &template_cluster
cluster:
main:
args: ["--log-level=TRACE", "--storage-properties-on-edges=True", "--storage-snapshot-interval-sec", "300", "--storage-wal-enabled=True"]
args:
[
"--log-level=TRACE",
"--storage-properties-on-edges=True",
"--storage-snapshot-interval-sec",
"300",
"--storage-wal-enabled=True",
]
log_file: "configuration-check-e2e.log"
setup_queries: []
validation_queries: []

View File

@ -24,3 +24,9 @@ add_subdirectory(mg_import_csv)
# license_check test binaries
add_subdirectory(license_info)
#environment variable check binaries
add_subdirectory(env_variable_check)
#flag check binaries
add_subdirectory(flag_check)

View File

@ -0,0 +1,7 @@
set(target_name memgraph__integration__env_variable_check)
set(tester_target_name ${target_name}__tester)
set(env_check_target_name ${target_name}__check)
add_executable(${tester_target_name} tester.cpp)
set_target_properties(${tester_target_name} PROPERTIES OUTPUT_NAME tester)
target_link_libraries(${tester_target_name} mg-communication)

View File

@ -0,0 +1,151 @@
#!/usr/bin/python3 -u
# Copyright 2022 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
# License, and you may not use this file except in compliance with the Business Source License.
#
# As of the Change Date specified in that file, in accordance with
# the Business Source License, use of this software will be governed
# by the Apache License, Version 2.0, included in the file
# licenses/APL.txt.
import argparse
import os
import subprocess
import sys
import tempfile
import time
from pathlib import Path
from typing import List
SCRIPT_DIR = Path(__file__).absolute()
PROJECT_DIR = SCRIPT_DIR.parents[3]
def wait_for_server(port, delay=0.1):
cmd = ["nc", "-z", "-w", "1", "127.0.0.1", str(port)]
while subprocess.call(cmd) != 0:
time.sleep(0.01)
time.sleep(delay)
def execute_tester(
binary: str,
queries: List[str],
should_fail: bool = False,
failure_message: str = "",
username: str = "",
password: str = "",
check_failure: bool = True,
) -> None:
args = [binary, "--username", username, "--password", password]
if should_fail:
args.append("--should-fail")
if failure_message:
args.extend(["--failure-message", failure_message])
if check_failure:
args.append("--check-failure")
args.extend(queries)
subprocess.run(args).check_returncode()
def start_memgraph(memgraph_args: List[any]) -> subprocess:
memgraph = subprocess.Popen(list(map(str, memgraph_args)))
time.sleep(0.1)
assert memgraph.poll() is None, "Memgraph process died prematurely!"
wait_for_server(7687)
return memgraph
def execute_with_user(queries):
return execute_tester(
tester_binary, queries, should_fail=False, check_failure=True, username="admin", password="admin"
)
def cleanup(memgraph):
if memgraph.poll() is None:
memgraph.terminate()
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
def execute_without_user(queries, should_fail=False, failure_message="", check_failure=True):
return execute_tester(tester_binary, queries, should_fail, failure_message, "", "", check_failure)
def test_without_env_variables(memgraph_args: List[any]) -> None:
memgraph = start_memgraph(memgraph_args)
execute_without_user(["MATCH (n) RETURN n"], False)
cleanup(memgraph)
def test_with_user_password_env_variables(memgraph_args: List[any]) -> None:
os.environ["MEMGRAPH_USER"] = "admin"
os.environ["MEMGRAPH_PASSWORD"] = "admin"
memgraph = start_memgraph(memgraph_args)
execute_with_user(["MATCH (n) RETURN n"])
execute_without_user(["MATCH (n) RETURN n"], True, "Handshake with the server failed!", True)
cleanup(memgraph)
del os.environ["MEMGRAPH_USER"]
del os.environ["MEMGRAPH_PASSWORD"]
def test_with_passfile_env_variable(storage_directory: tempfile.TemporaryDirectory, memgraph_args: List[any]) -> None:
with open(os.path.join(storage_directory.name, "passfile.txt"), "w") as temp_file:
temp_file.write("admin:admin")
os.environ["MEMGRAPH_PASSFILE"] = storage_directory.name + "/passfile.txt"
memgraph = start_memgraph(memgraph_args)
execute_with_user(["MATCH (n) RETURN n"])
execute_without_user(["MATCH (n) RETURN n"], True, "Handshake with the server failed!", True)
del os.environ["MEMGRAPH_PASSFILE"]
cleanup(memgraph)
def execute_test(memgraph_binary: str, tester_binary: str) -> None:
storage_directory = tempfile.TemporaryDirectory()
memgraph_args = [memgraph_binary, "--data-directory", storage_directory.name]
return_to_prev_state = {}
if "MEMGRAPH_USER" in os.environ:
return_to_prev_state["MEMGRAPH_USER"] = os.environ["MEMGRAPH_USER"]
del os.environ["MG_USER"]
if "MEMGRAPH_PASSWORD" in os.environ:
return_to_prev_state["MEMGRAPH_PASSWORD"] = os.environ["MEMGRAPH_PASSWORD"]
del os.environ["MEMGRAPH_PASSWORD"]
if "MEMGRAPH_PASSFILE" in os.environ:
return_to_prev_state["MEMGRAPH_PASSFILE"] = os.environ["MEMGRAPH_PASSFILE"]
del os.environ["MEMGRAPH_PASSFILE"]
# Start the memgraph binary
# Run the test with all combinations of permissions
print("\033[1;36m~~ Starting env variable check test ~~\033[0m")
test_without_env_variables(memgraph_args)
test_with_user_password_env_variables(memgraph_args)
test_with_passfile_env_variable(storage_directory, memgraph_args)
print("\033[1;36m~~ Ended env variable check test ~~\033[0m")
if "MEMGRAPH_USER" in return_to_prev_state:
os.environ["MEMGRAPH_USER"] = return_to_prev_state["MEMGRAPH_USER"]
if "MEMGRAPH_PASSWORD" in return_to_prev_state:
os.environ["MEMGRAPH_PASSWORD"] = return_to_prev_state["MEMGRAPH_PASSWORD"]
if "MEMGRAPH_PASSFILE" in return_to_prev_state:
os.environ["MEMGRAPH_PASSFILE"] = return_to_prev_state["MEMGRAPH_PASSFILE"]
if __name__ == "__main__":
memgraph_binary = os.path.join(PROJECT_DIR, "build", "memgraph")
tester_binary = os.path.join(PROJECT_DIR, "build", "tests", "integration", "env_variable_check", "tester")
parser = argparse.ArgumentParser()
parser.add_argument("--memgraph", default=memgraph_binary)
parser.add_argument("--tester", default=tester_binary)
args = parser.parse_args()
execute_test(args.memgraph, args.tester)
sys.exit(0)

View File

@ -0,0 +1,94 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <gflags/gflags.h>
#include "communication/bolt/client.hpp"
#include "io/network/endpoint.hpp"
#include "io/network/utils.hpp"
DEFINE_string(address, "127.0.0.1", "Server address");
DEFINE_int32(port, 7687, "Server port");
DEFINE_string(username, "", "Username for the database");
DEFINE_string(password, "", "Password for the database");
DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server.");
DEFINE_bool(check_failure, false, "Set to true to enable failure checking.");
DEFINE_bool(should_fail, false, "Set to true to expect a failure.");
DEFINE_string(failure_message, "", "Set to the expected failure message.");
int ProcessException(const std::string &exception_message) {
if (FLAGS_should_fail) {
if (!FLAGS_failure_message.empty() && exception_message != FLAGS_failure_message) {
LOG_FATAL(
"The query should have failed with an error message of '{}'' but "
"instead it failed with '{}'",
FLAGS_failure_message, exception_message);
}
return 0;
} else {
LOG_FATAL(
"The query shoudn't have failed but it failed with an "
"error message '{}'",
exception_message);
return 1;
}
}
/**
* Executes queries passed as positional arguments and verifies whether they
* succeeded, failed, failed with a specific error message or executed without a
* specific error occurring.
*/
int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
memgraph::communication::SSLInit sslInit;
memgraph::io::network::Endpoint endpoint(memgraph::io::network::ResolveHostname(FLAGS_address), FLAGS_port);
memgraph::communication::ClientContext context(FLAGS_use_ssl);
memgraph::communication::bolt::Client client(context);
try {
client.Connect(endpoint, FLAGS_username, FLAGS_password);
} catch (const memgraph::utils::BasicException &e) {
return ProcessException(e.what());
}
for (int i = 1; i < argc; ++i) {
std::string query(argv[i]);
try {
client.Execute(query, {});
} catch (const memgraph::communication::bolt::ClientQueryException &e) {
if (!FLAGS_check_failure) {
if (!FLAGS_failure_message.empty() && e.what() == FLAGS_failure_message) {
LOG_FATAL(
"The query should have succeeded or failed with an error "
"message that isn't equal to '{}' but it failed with that error "
"message",
FLAGS_failure_message);
}
continue;
}
if (!ProcessException(e.what())) {
return 0;
}
}
if (!FLAGS_check_failure) continue;
if (FLAGS_should_fail) {
LOG_FATAL(
"The query should have failed but instead it executed "
"successfully!");
}
}
return 0;
}

View File

@ -0,0 +1,11 @@
set(target_name memgraph__integration__flag_check)
set(tester_target_name ${target_name}__tester)
set(flag_check_target_name ${target_name}__flag_check)
add_executable(${tester_target_name} tester.cpp)
set_target_properties(${tester_target_name} PROPERTIES OUTPUT_NAME tester)
target_link_libraries(${tester_target_name} mg-communication)
add_executable(${flag_check_target_name} flag_check.cpp)
set_target_properties(${flag_check_target_name} PROPERTIES OUTPUT_NAME flag_check)
target_link_libraries(${flag_check_target_name} mg-communication)

View File

@ -0,0 +1,59 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <gflags/gflags.h>
#include <cstdlib>
#include "communication/bolt/client.hpp"
#include "io/network/endpoint.hpp"
#include "io/network/utils.hpp"
#include "utils/logging.hpp"
DEFINE_string(address, "127.0.0.1", "Server address");
DEFINE_int32(port, 7687, "Server port");
DEFINE_string(username, "admin", "Username for the database");
DEFINE_string(password, "admin", "Password for the database");
DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server.");
/**
* Verifies that user 'user' has privileges that are given as positional
* arguments.
*/
int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
memgraph::communication::SSLInit sslInit;
memgraph::io::network::Endpoint endpoint(memgraph::io::network::ResolveHostname(FLAGS_address), FLAGS_port);
memgraph::communication::ClientContext context(FLAGS_use_ssl);
memgraph::communication::bolt::Client client(context);
client.Connect(endpoint, FLAGS_username, FLAGS_password);
try {
std::string query(argv[1]);
auto ret = client.Execute(query, {});
uint64_t count_got = ret.records.size();
if (count_got != std::atoi(argv[2])) {
LOG_FATAL("Expected the record to have {} entries but they had {} entries!", argv[2], count_got);
}
} catch (const memgraph::communication::bolt::ClientQueryException &e) {
LOG_FATAL(
"The query shoudn't have failed but it failed with an "
"error message '{}', {}",
e.what(), argv[0]);
}
return 0;
}

View File

@ -0,0 +1,173 @@
#!/usr/bin/python3 -u
# Copyright 2022 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
# License, and you may not use this file except in compliance with the Business Source License.
#
# As of the Change Date specified in that file, in accordance with
# the Business Source License, use of this software will be governed
# by the Apache License, Version 2.0, included in the file
# licenses/APL.txt.
import argparse
import os
import subprocess
import sys
import tempfile
import time
from typing import List
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
def wait_for_server(port: int, delay: float = 0.1) -> float:
cmd = ["nc", "-z", "-w", "1", "127.0.0.1", str(port)]
while subprocess.call(cmd) != 0:
time.sleep(0.01)
time.sleep(delay)
def execute_tester(
binary: str,
queries: List[str],
should_fail: bool = False,
failure_message: str = "",
username: str = "",
password: str = "",
check_failure: bool = True,
) -> None:
args = [binary, "--username", username, "--password", password]
if should_fail:
args.append("--should-fail")
if failure_message:
args.extend(["--failure-message", failure_message])
if check_failure:
args.append("--check-failure")
args.extend(queries)
subprocess.run(args).check_returncode()
def execute_flag_check(binary: str, queries: List[str], expected: int, username: str = "", password: str = "") -> None:
args = [binary, "--username", username, "--password", password]
args.extend(queries)
args.append(str(expected))
subprocess.run(args).check_returncode()
def start_memgraph(memgraph_args: List[any]) -> subprocess:
memgraph = subprocess.Popen(list(map(str, memgraph_args)))
time.sleep(0.1)
assert memgraph.poll() is None, "Memgraph process died prematurely!"
wait_for_server(7687)
return memgraph
def execute_with_user(tester_binary: str, queries: List[str]) -> None:
return execute_tester(
tester_binary, queries, should_fail=False, check_failure=True, username="admin", password="admin"
)
def execute_without_user(
tester_binary: str,
queries: List[str],
should_fail: bool = False,
failure_message: str = "",
check_failure: bool = True,
) -> None:
return execute_tester(tester_binary, queries, should_fail, failure_message, "", "", check_failure)
def cleanup(memgraph: subprocess):
if memgraph.poll() is None:
memgraph.terminate()
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
def test_without_any_files(tester_binary: str, memgraph_args: List[str]):
memgraph = start_memgraph(memgraph_args)
execute_without_user(tester_binary, ["MATCH (n) RETURN n"], False)
cleanup(memgraph)
def test_init_file(tester_binary: str, memgraph_args: List[str]):
memgraph = start_memgraph(memgraph_args)
execute_with_user(tester_binary, ["MATCH (n) RETURN n"])
execute_without_user(tester_binary, ["MATCH (n) RETURN n"], True, "Handshake with the server failed!", True)
cleanup(memgraph)
def test_init_data_file(flag_checker_binary: str, memgraph_args: List[str]):
memgraph = start_memgraph(memgraph_args)
execute_flag_check(flag_checker_binary, ["MATCH (n) RETURN n"], 2, "user", "user")
cleanup(memgraph)
def test_init_and_init_data_file(flag_checker_binary: str, tester_binary: str, memgraph_args: List[str]):
memgraph = start_memgraph(memgraph_args)
execute_with_user(tester_binary, ["MATCH (n) RETURN n"])
execute_without_user(tester_binary, ["MATCH (n) RETURN n"], True, "Handshake with the server failed!", True)
execute_flag_check(flag_checker_binary, ["MATCH (n) RETURN n"], 2, "user", "user")
cleanup(memgraph)
def execute_test(memgraph_binary: str, tester_binary: str, flag_checker_binary: str) -> None:
storage_directory = tempfile.TemporaryDirectory()
memgraph_args = [memgraph_binary, "--data-directory", storage_directory.name]
# Start the memgraph binary
with open(os.path.join(os.getcwd(), "dummy_init_file.cypherl"), "w") as temp_file:
temp_file.write("CREATE USER admin IDENTIFIED BY 'admin';\n")
temp_file.write("CREATE USER user IDENTIFIED BY 'user';\n")
with open(os.path.join(os.getcwd(), "dummy_init_data_file.cypherl"), "w") as temp_file:
temp_file.write("CREATE (n:RANDOM) RETURN n;\n")
temp_file.write("CREATE (n:RANDOM {name:'1'}) RETURN n;\n")
# Run the test with all combinations of permissions
print("\033[1;36m~~ Starting env variable check test ~~\033[0m")
test_without_any_files(tester_binary, memgraph_args)
memgraph_args_with_init_file = memgraph_args + [
"--init-file",
os.path.join(os.getcwd(), "dummy_init_file.cypherl"),
]
test_init_file(tester_binary, memgraph_args_with_init_file)
memgraph_args_with_init_data_file = memgraph_args + [
"--init-data-file",
os.path.join(os.getcwd(), "dummy_init_data_file.cypherl"),
]
test_init_data_file(flag_checker_binary, memgraph_args_with_init_data_file)
memgraph_args_with_init_file_and_init_data_file = memgraph_args + [
"--init-file",
os.path.join(os.getcwd(), "dummy_init_file.cypherl"),
"--init-data-file",
os.path.join(os.getcwd(), "dummy_init_data_file.cypherl"),
]
test_init_and_init_data_file(flag_checker_binary, tester_binary, memgraph_args_with_init_file_and_init_data_file)
print("\033[1;36m~~ Ended env variable check test ~~\033[0m")
os.remove(os.path.join(os.getcwd(), "dummy_init_data_file.cypherl"))
os.remove(os.path.join(os.getcwd(), "dummy_init_file.cypherl"))
if __name__ == "__main__":
memgraph_binary = os.path.join(PROJECT_DIR, "build", "memgraph")
tester_binary = os.path.join(PROJECT_DIR, "build", "tests", "integration", "flag_check", "tester")
flag_checker_binary = os.path.join(PROJECT_DIR, "build", "tests", "integration", "flag_check", "flag_check")
parser = argparse.ArgumentParser()
parser.add_argument("--memgraph", default=memgraph_binary)
parser.add_argument("--tester", default=tester_binary)
parser.add_argument("--flag_checker", default=flag_checker_binary)
args = parser.parse_args()
execute_test(args.memgraph, args.tester, args.flag_checker)
sys.exit(0)

View File

@ -0,0 +1,94 @@
// Copyright 2022 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#include <gflags/gflags.h>
#include "communication/bolt/client.hpp"
#include "io/network/endpoint.hpp"
#include "io/network/utils.hpp"
DEFINE_string(address, "127.0.0.1", "Server address");
DEFINE_int32(port, 7687, "Server port");
DEFINE_string(username, "", "Username for the database");
DEFINE_string(password, "", "Password for the database");
DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server.");
DEFINE_bool(check_failure, false, "Set to true to enable failure checking.");
DEFINE_bool(should_fail, false, "Set to true to expect a failure.");
DEFINE_string(failure_message, "", "Set to the expected failure message.");
int ProcessException(const std::string &exception_message) {
if (FLAGS_should_fail) {
if (!FLAGS_failure_message.empty() && exception_message != FLAGS_failure_message) {
LOG_FATAL(
"The query should have failed with an error message of '{}'' but "
"instead it failed with '{}'",
FLAGS_failure_message, exception_message);
}
return 0;
} else {
LOG_FATAL(
"The query shoudn't have failed but it failed with an "
"error message '{}'",
exception_message);
return 1;
}
}
/**
* Executes queries passed as positional arguments and verifies whether they
* succeeded, failed, failed with a specific error message or executed without a
* specific error occurring.
*/
int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
memgraph::communication::SSLInit sslInit;
memgraph::io::network::Endpoint endpoint(memgraph::io::network::ResolveHostname(FLAGS_address), FLAGS_port);
memgraph::communication::ClientContext context(FLAGS_use_ssl);
memgraph::communication::bolt::Client client(context);
try {
client.Connect(endpoint, FLAGS_username, FLAGS_password);
} catch (const memgraph::utils::BasicException &e) {
return ProcessException(e.what());
}
for (int i = 1; i < argc; ++i) {
std::string query(argv[i]);
try {
client.Execute(query, {});
} catch (const memgraph::communication::bolt::ClientQueryException &e) {
if (!FLAGS_check_failure) {
if (!FLAGS_failure_message.empty() && e.what() == FLAGS_failure_message) {
LOG_FATAL(
"The query should have succeeded or failed with an error "
"message that isn't equal to '{}' but it failed with that error "
"message",
FLAGS_failure_message);
}
continue;
}
if (!ProcessException(e.what())) {
return 0;
}
}
if (!FLAGS_check_failure) continue;
if (FLAGS_should_fail) {
LOG_FATAL(
"The query should have failed but instead it executed "
"successfully!");
}
}
return 0;
}