Add init-file and init-data-file capabilities (#696)
This commit is contained in:
parent
f2d5ab61c4
commit
9d6a23b6bd
@ -1,4 +1,4 @@
|
||||
// Copyright 2021 Memgraph Ltd.
|
||||
// Copyright 2022 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
@ -12,6 +12,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
@ -61,3 +62,42 @@ inline void LoadConfig(const std::string &product_name) {
|
||||
for (int i = 0; i < custom_argc; ++i) free(custom_argv[i]);
|
||||
delete[] custom_argv;
|
||||
}
|
||||
|
||||
std::pair<std::string, std::string> LoadUsernameAndPassword(const std::string &pass_file) {
|
||||
std::ifstream file(pass_file);
|
||||
if (file.fail()) {
|
||||
spdlog::warn("Problem with opening MG_PASSFILE, memgraph server will start without user");
|
||||
return {};
|
||||
}
|
||||
std::vector<std::string> result;
|
||||
|
||||
std::string line;
|
||||
std::getline(file, line);
|
||||
size_t pos = 0;
|
||||
std::string token;
|
||||
static constexpr std::string_view delimiter{":"};
|
||||
while ((pos = line.find(delimiter)) != std::string::npos) {
|
||||
if (line[pos - 1] == '\\') {
|
||||
line.erase(pos - 1, 1);
|
||||
token += line.substr(0, pos);
|
||||
line.erase(0, pos);
|
||||
|
||||
} else {
|
||||
token += line.substr(0, pos);
|
||||
result.push_back(token);
|
||||
line.erase(0, pos + delimiter.length());
|
||||
token = "";
|
||||
}
|
||||
}
|
||||
result.push_back(line);
|
||||
file.close();
|
||||
|
||||
if (result.size() != 2) {
|
||||
spdlog::warn(
|
||||
"Wrong data format. Data should be store in format: username:password, memgraph server will start without "
|
||||
"user");
|
||||
return {};
|
||||
}
|
||||
|
||||
return {result[0], result[1]};
|
||||
}
|
||||
|
@ -26,6 +26,7 @@
|
||||
#include <string_view>
|
||||
#include <thread>
|
||||
|
||||
#include <fmt/core.h>
|
||||
#include <fmt/format.h>
|
||||
#include <gflags/gflags.h>
|
||||
#include <spdlog/common.h>
|
||||
@ -98,6 +99,10 @@
|
||||
#include "audit/log.hpp"
|
||||
#endif
|
||||
|
||||
constexpr const char *kMgUser = "MEMGRAPH_USER";
|
||||
constexpr const char *kMgPassword = "MEMGRAPH_PASSWORD";
|
||||
constexpr const char *kMgPassfile = "MEMGRAPH_PASSFILE";
|
||||
|
||||
namespace {
|
||||
std::string GetAllowedEnumValuesString(const auto &mappings) {
|
||||
std::vector<std::string> allowed_values;
|
||||
@ -167,6 +172,11 @@ DEFINE_string(bolt_key_file, "", "Key file which should be used for the Bolt ser
|
||||
DEFINE_string(bolt_server_name_for_init, "",
|
||||
"Server name which the database should send to the client in the "
|
||||
"Bolt INIT message.");
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_string(init_file, "",
|
||||
"Path to cypherl file that is used for configuring users and database schema before server starts.");
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
DEFINE_string(init_data_file, "", "Path to cypherl file that is used for creating data after server starts.");
|
||||
|
||||
// General purpose flags.
|
||||
// NOTE: The `data_directory` flag must be the same here and in
|
||||
@ -476,6 +486,33 @@ struct SessionData {
|
||||
DEFINE_string(auth_user_or_role_name_regex, memgraph::glue::kDefaultUserRoleRegex.data(),
|
||||
"Set to the regular expression that each user or role name must fulfill.");
|
||||
|
||||
void InitFromCypherlFile(memgraph::query::InterpreterContext &ctx, std::string cypherl_file_path
|
||||
#ifdef MG_ENTERPRISE
|
||||
,
|
||||
memgraph::audit::Log *audit_log
|
||||
#endif
|
||||
) {
|
||||
memgraph::query::Interpreter interpreter(&ctx);
|
||||
std::ifstream file(cypherl_file_path);
|
||||
if (file.is_open()) {
|
||||
std::string line;
|
||||
while (std::getline(file, line)) {
|
||||
if (!line.empty()) {
|
||||
auto results = interpreter.Prepare(line, {}, {});
|
||||
memgraph::query::DiscardValueResultStream stream;
|
||||
interpreter.Pull(&stream, {}, results.qid);
|
||||
|
||||
#ifdef MG_ENTERPRISE
|
||||
if (memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
|
||||
audit_log->Record("", "", line, {});
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
file.close();
|
||||
}
|
||||
}
|
||||
|
||||
class BoltSession final : public memgraph::communication::bolt::Session<memgraph::communication::v2::InputStream,
|
||||
memgraph::communication::v2::OutputStream> {
|
||||
public:
|
||||
@ -889,6 +926,29 @@ int main(int argc, char **argv) {
|
||||
interpreter_context.auth = &auth_handler;
|
||||
interpreter_context.auth_checker = &auth_checker;
|
||||
|
||||
if (!FLAGS_init_file.empty()) {
|
||||
spdlog::info("Running init file.");
|
||||
#ifdef MG_ENTERPRISE
|
||||
if (memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
|
||||
InitFromCypherlFile(interpreter_context, FLAGS_init_file, &audit_log);
|
||||
}
|
||||
#else
|
||||
InitFromCypherlFile(interpreter_context, FLAGS_init_file);
|
||||
#endif
|
||||
}
|
||||
|
||||
auto *maybe_username = std::getenv(kMgUser);
|
||||
auto *maybe_password = std::getenv(kMgPassword);
|
||||
auto *maybe_pass_file = std::getenv(kMgPassfile);
|
||||
if (maybe_username && maybe_password) {
|
||||
auth_handler.CreateUser(maybe_username, maybe_password);
|
||||
} else if (maybe_pass_file) {
|
||||
const auto [username, password] = LoadUsernameAndPassword(maybe_pass_file);
|
||||
if (!username.empty() && !password.empty()) {
|
||||
auth_handler.CreateUser(username, password);
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// Triggers can execute query procedures, so we need to reload the modules first and then
|
||||
// the triggers
|
||||
@ -966,6 +1026,17 @@ int main(int argc, char **argv) {
|
||||
MG_ASSERT(server.Start(), "Couldn't start the Bolt server!");
|
||||
websocket_server.Start();
|
||||
|
||||
if (!FLAGS_init_data_file.empty()) {
|
||||
spdlog::info("Running init data file.");
|
||||
#ifdef MG_ENTERPRISE
|
||||
if (memgraph::license::global_license_checker.IsEnterpriseValidFast()) {
|
||||
InitFromCypherlFile(interpreter_context, FLAGS_init_data_file, &audit_log);
|
||||
}
|
||||
#else
|
||||
InitFromCypherlFile(interpreter_context, FLAGS_init_data_file);
|
||||
#endif
|
||||
}
|
||||
|
||||
server.AwaitShutdown();
|
||||
websocket_server.AwaitShutdown();
|
||||
|
||||
|
@ -163,6 +163,12 @@ startup_config_dict = {
|
||||
),
|
||||
"query_max_plans": ("1000", "1000", "Maximum number of generated plans for a query."),
|
||||
"flag_file": ("", "", "load flags from file"),
|
||||
"init_file": (
|
||||
"",
|
||||
"",
|
||||
"Path to cypherl file that is used for configuring users and database schema before server starts.",
|
||||
),
|
||||
"init_data_file": ("", "", "Path to cypherl file that is used for creating data after server starts."),
|
||||
"python_submodules_directory": (
|
||||
"mage",
|
||||
"mage",
|
||||
|
@ -1,7 +1,14 @@
|
||||
template_cluster: &template_cluster
|
||||
cluster:
|
||||
main:
|
||||
args: ["--log-level=TRACE", "--storage-properties-on-edges=True", "--storage-snapshot-interval-sec", "300", "--storage-wal-enabled=True"]
|
||||
args:
|
||||
[
|
||||
"--log-level=TRACE",
|
||||
"--storage-properties-on-edges=True",
|
||||
"--storage-snapshot-interval-sec",
|
||||
"300",
|
||||
"--storage-wal-enabled=True",
|
||||
]
|
||||
log_file: "configuration-check-e2e.log"
|
||||
setup_queries: []
|
||||
validation_queries: []
|
||||
|
@ -24,3 +24,9 @@ add_subdirectory(mg_import_csv)
|
||||
|
||||
# license_check test binaries
|
||||
add_subdirectory(license_info)
|
||||
|
||||
#environment variable check binaries
|
||||
add_subdirectory(env_variable_check)
|
||||
|
||||
#flag check binaries
|
||||
add_subdirectory(flag_check)
|
||||
|
7
tests/integration/env_variable_check/CMakeLists.txt
Normal file
7
tests/integration/env_variable_check/CMakeLists.txt
Normal file
@ -0,0 +1,7 @@
|
||||
set(target_name memgraph__integration__env_variable_check)
|
||||
set(tester_target_name ${target_name}__tester)
|
||||
set(env_check_target_name ${target_name}__check)
|
||||
|
||||
add_executable(${tester_target_name} tester.cpp)
|
||||
set_target_properties(${tester_target_name} PROPERTIES OUTPUT_NAME tester)
|
||||
target_link_libraries(${tester_target_name} mg-communication)
|
151
tests/integration/env_variable_check/runner.py
Normal file
151
tests/integration/env_variable_check/runner.py
Normal file
@ -0,0 +1,151 @@
|
||||
#!/usr/bin/python3 -u
|
||||
|
||||
# Copyright 2022 Memgraph Ltd.
|
||||
#
|
||||
# Use of this software is governed by the Business Source License
|
||||
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
# License, and you may not use this file except in compliance with the Business Source License.
|
||||
#
|
||||
# As of the Change Date specified in that file, in accordance with
|
||||
# the Business Source License, use of this software will be governed
|
||||
# by the Apache License, Version 2.0, included in the file
|
||||
# licenses/APL.txt.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
SCRIPT_DIR = Path(__file__).absolute()
|
||||
PROJECT_DIR = SCRIPT_DIR.parents[3]
|
||||
|
||||
|
||||
def wait_for_server(port, delay=0.1):
|
||||
cmd = ["nc", "-z", "-w", "1", "127.0.0.1", str(port)]
|
||||
while subprocess.call(cmd) != 0:
|
||||
time.sleep(0.01)
|
||||
time.sleep(delay)
|
||||
|
||||
|
||||
def execute_tester(
|
||||
binary: str,
|
||||
queries: List[str],
|
||||
should_fail: bool = False,
|
||||
failure_message: str = "",
|
||||
username: str = "",
|
||||
password: str = "",
|
||||
check_failure: bool = True,
|
||||
) -> None:
|
||||
args = [binary, "--username", username, "--password", password]
|
||||
if should_fail:
|
||||
args.append("--should-fail")
|
||||
if failure_message:
|
||||
args.extend(["--failure-message", failure_message])
|
||||
if check_failure:
|
||||
args.append("--check-failure")
|
||||
args.extend(queries)
|
||||
subprocess.run(args).check_returncode()
|
||||
|
||||
|
||||
def start_memgraph(memgraph_args: List[any]) -> subprocess:
|
||||
memgraph = subprocess.Popen(list(map(str, memgraph_args)))
|
||||
time.sleep(0.1)
|
||||
assert memgraph.poll() is None, "Memgraph process died prematurely!"
|
||||
wait_for_server(7687)
|
||||
|
||||
return memgraph
|
||||
|
||||
|
||||
def execute_with_user(queries):
|
||||
return execute_tester(
|
||||
tester_binary, queries, should_fail=False, check_failure=True, username="admin", password="admin"
|
||||
)
|
||||
|
||||
|
||||
def cleanup(memgraph):
|
||||
if memgraph.poll() is None:
|
||||
memgraph.terminate()
|
||||
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
|
||||
|
||||
|
||||
def execute_without_user(queries, should_fail=False, failure_message="", check_failure=True):
|
||||
return execute_tester(tester_binary, queries, should_fail, failure_message, "", "", check_failure)
|
||||
|
||||
|
||||
def test_without_env_variables(memgraph_args: List[any]) -> None:
|
||||
memgraph = start_memgraph(memgraph_args)
|
||||
execute_without_user(["MATCH (n) RETURN n"], False)
|
||||
cleanup(memgraph)
|
||||
|
||||
|
||||
def test_with_user_password_env_variables(memgraph_args: List[any]) -> None:
|
||||
os.environ["MEMGRAPH_USER"] = "admin"
|
||||
os.environ["MEMGRAPH_PASSWORD"] = "admin"
|
||||
memgraph = start_memgraph(memgraph_args)
|
||||
execute_with_user(["MATCH (n) RETURN n"])
|
||||
execute_without_user(["MATCH (n) RETURN n"], True, "Handshake with the server failed!", True)
|
||||
cleanup(memgraph)
|
||||
del os.environ["MEMGRAPH_USER"]
|
||||
del os.environ["MEMGRAPH_PASSWORD"]
|
||||
|
||||
|
||||
def test_with_passfile_env_variable(storage_directory: tempfile.TemporaryDirectory, memgraph_args: List[any]) -> None:
|
||||
with open(os.path.join(storage_directory.name, "passfile.txt"), "w") as temp_file:
|
||||
temp_file.write("admin:admin")
|
||||
|
||||
os.environ["MEMGRAPH_PASSFILE"] = storage_directory.name + "/passfile.txt"
|
||||
memgraph = start_memgraph(memgraph_args)
|
||||
execute_with_user(["MATCH (n) RETURN n"])
|
||||
execute_without_user(["MATCH (n) RETURN n"], True, "Handshake with the server failed!", True)
|
||||
del os.environ["MEMGRAPH_PASSFILE"]
|
||||
cleanup(memgraph)
|
||||
|
||||
|
||||
def execute_test(memgraph_binary: str, tester_binary: str) -> None:
|
||||
storage_directory = tempfile.TemporaryDirectory()
|
||||
memgraph_args = [memgraph_binary, "--data-directory", storage_directory.name]
|
||||
|
||||
return_to_prev_state = {}
|
||||
if "MEMGRAPH_USER" in os.environ:
|
||||
return_to_prev_state["MEMGRAPH_USER"] = os.environ["MEMGRAPH_USER"]
|
||||
del os.environ["MG_USER"]
|
||||
if "MEMGRAPH_PASSWORD" in os.environ:
|
||||
return_to_prev_state["MEMGRAPH_PASSWORD"] = os.environ["MEMGRAPH_PASSWORD"]
|
||||
del os.environ["MEMGRAPH_PASSWORD"]
|
||||
if "MEMGRAPH_PASSFILE" in os.environ:
|
||||
return_to_prev_state["MEMGRAPH_PASSFILE"] = os.environ["MEMGRAPH_PASSFILE"]
|
||||
del os.environ["MEMGRAPH_PASSFILE"]
|
||||
|
||||
# Start the memgraph binary
|
||||
|
||||
# Run the test with all combinations of permissions
|
||||
print("\033[1;36m~~ Starting env variable check test ~~\033[0m")
|
||||
test_without_env_variables(memgraph_args)
|
||||
test_with_user_password_env_variables(memgraph_args)
|
||||
test_with_passfile_env_variable(storage_directory, memgraph_args)
|
||||
print("\033[1;36m~~ Ended env variable check test ~~\033[0m")
|
||||
|
||||
if "MEMGRAPH_USER" in return_to_prev_state:
|
||||
os.environ["MEMGRAPH_USER"] = return_to_prev_state["MEMGRAPH_USER"]
|
||||
if "MEMGRAPH_PASSWORD" in return_to_prev_state:
|
||||
os.environ["MEMGRAPH_PASSWORD"] = return_to_prev_state["MEMGRAPH_PASSWORD"]
|
||||
if "MEMGRAPH_PASSFILE" in return_to_prev_state:
|
||||
os.environ["MEMGRAPH_PASSFILE"] = return_to_prev_state["MEMGRAPH_PASSFILE"]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
memgraph_binary = os.path.join(PROJECT_DIR, "build", "memgraph")
|
||||
tester_binary = os.path.join(PROJECT_DIR, "build", "tests", "integration", "env_variable_check", "tester")
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--memgraph", default=memgraph_binary)
|
||||
parser.add_argument("--tester", default=tester_binary)
|
||||
args = parser.parse_args()
|
||||
|
||||
execute_test(args.memgraph, args.tester)
|
||||
|
||||
sys.exit(0)
|
94
tests/integration/env_variable_check/tester.cpp
Normal file
94
tests/integration/env_variable_check/tester.cpp
Normal file
@ -0,0 +1,94 @@
|
||||
// Copyright 2022 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
// License, and you may not use this file except in compliance with the Business Source License.
|
||||
//
|
||||
// As of the Change Date specified in that file, in accordance with
|
||||
// the Business Source License, use of this software will be governed
|
||||
// by the Apache License, Version 2.0, included in the file
|
||||
// licenses/APL.txt.
|
||||
|
||||
#include <gflags/gflags.h>
|
||||
|
||||
#include "communication/bolt/client.hpp"
|
||||
#include "io/network/endpoint.hpp"
|
||||
#include "io/network/utils.hpp"
|
||||
|
||||
DEFINE_string(address, "127.0.0.1", "Server address");
|
||||
DEFINE_int32(port, 7687, "Server port");
|
||||
DEFINE_string(username, "", "Username for the database");
|
||||
DEFINE_string(password, "", "Password for the database");
|
||||
DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server.");
|
||||
|
||||
DEFINE_bool(check_failure, false, "Set to true to enable failure checking.");
|
||||
DEFINE_bool(should_fail, false, "Set to true to expect a failure.");
|
||||
DEFINE_string(failure_message, "", "Set to the expected failure message.");
|
||||
|
||||
int ProcessException(const std::string &exception_message) {
|
||||
if (FLAGS_should_fail) {
|
||||
if (!FLAGS_failure_message.empty() && exception_message != FLAGS_failure_message) {
|
||||
LOG_FATAL(
|
||||
"The query should have failed with an error message of '{}'' but "
|
||||
"instead it failed with '{}'",
|
||||
FLAGS_failure_message, exception_message);
|
||||
}
|
||||
return 0;
|
||||
} else {
|
||||
LOG_FATAL(
|
||||
"The query shoudn't have failed but it failed with an "
|
||||
"error message '{}'",
|
||||
exception_message);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
/**
|
||||
* Executes queries passed as positional arguments and verifies whether they
|
||||
* succeeded, failed, failed with a specific error message or executed without a
|
||||
* specific error occurring.
|
||||
*/
|
||||
int main(int argc, char **argv) {
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||
|
||||
memgraph::communication::SSLInit sslInit;
|
||||
|
||||
memgraph::io::network::Endpoint endpoint(memgraph::io::network::ResolveHostname(FLAGS_address), FLAGS_port);
|
||||
|
||||
memgraph::communication::ClientContext context(FLAGS_use_ssl);
|
||||
memgraph::communication::bolt::Client client(context);
|
||||
|
||||
try {
|
||||
client.Connect(endpoint, FLAGS_username, FLAGS_password);
|
||||
} catch (const memgraph::utils::BasicException &e) {
|
||||
return ProcessException(e.what());
|
||||
}
|
||||
|
||||
for (int i = 1; i < argc; ++i) {
|
||||
std::string query(argv[i]);
|
||||
try {
|
||||
client.Execute(query, {});
|
||||
} catch (const memgraph::communication::bolt::ClientQueryException &e) {
|
||||
if (!FLAGS_check_failure) {
|
||||
if (!FLAGS_failure_message.empty() && e.what() == FLAGS_failure_message) {
|
||||
LOG_FATAL(
|
||||
"The query should have succeeded or failed with an error "
|
||||
"message that isn't equal to '{}' but it failed with that error "
|
||||
"message",
|
||||
FLAGS_failure_message);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (!ProcessException(e.what())) {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
if (!FLAGS_check_failure) continue;
|
||||
if (FLAGS_should_fail) {
|
||||
LOG_FATAL(
|
||||
"The query should have failed but instead it executed "
|
||||
"successfully!");
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
11
tests/integration/flag_check/CMakeLists.txt
Normal file
11
tests/integration/flag_check/CMakeLists.txt
Normal file
@ -0,0 +1,11 @@
|
||||
set(target_name memgraph__integration__flag_check)
|
||||
set(tester_target_name ${target_name}__tester)
|
||||
set(flag_check_target_name ${target_name}__flag_check)
|
||||
|
||||
add_executable(${tester_target_name} tester.cpp)
|
||||
set_target_properties(${tester_target_name} PROPERTIES OUTPUT_NAME tester)
|
||||
target_link_libraries(${tester_target_name} mg-communication)
|
||||
|
||||
add_executable(${flag_check_target_name} flag_check.cpp)
|
||||
set_target_properties(${flag_check_target_name} PROPERTIES OUTPUT_NAME flag_check)
|
||||
target_link_libraries(${flag_check_target_name} mg-communication)
|
59
tests/integration/flag_check/flag_check.cpp
Normal file
59
tests/integration/flag_check/flag_check.cpp
Normal file
@ -0,0 +1,59 @@
|
||||
// Copyright 2022 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
// License, and you may not use this file except in compliance with the Business Source License.
|
||||
//
|
||||
// As of the Change Date specified in that file, in accordance with
|
||||
// the Business Source License, use of this software will be governed
|
||||
// by the Apache License, Version 2.0, included in the file
|
||||
// licenses/APL.txt.
|
||||
|
||||
#include <gflags/gflags.h>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "communication/bolt/client.hpp"
|
||||
#include "io/network/endpoint.hpp"
|
||||
#include "io/network/utils.hpp"
|
||||
#include "utils/logging.hpp"
|
||||
|
||||
DEFINE_string(address, "127.0.0.1", "Server address");
|
||||
DEFINE_int32(port, 7687, "Server port");
|
||||
DEFINE_string(username, "admin", "Username for the database");
|
||||
DEFINE_string(password, "admin", "Password for the database");
|
||||
DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server.");
|
||||
|
||||
/**
|
||||
* Verifies that user 'user' has privileges that are given as positional
|
||||
* arguments.
|
||||
*/
|
||||
int main(int argc, char **argv) {
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||
|
||||
memgraph::communication::SSLInit sslInit;
|
||||
|
||||
memgraph::io::network::Endpoint endpoint(memgraph::io::network::ResolveHostname(FLAGS_address), FLAGS_port);
|
||||
|
||||
memgraph::communication::ClientContext context(FLAGS_use_ssl);
|
||||
memgraph::communication::bolt::Client client(context);
|
||||
|
||||
client.Connect(endpoint, FLAGS_username, FLAGS_password);
|
||||
|
||||
try {
|
||||
std::string query(argv[1]);
|
||||
auto ret = client.Execute(query, {});
|
||||
uint64_t count_got = ret.records.size();
|
||||
|
||||
if (count_got != std::atoi(argv[2])) {
|
||||
LOG_FATAL("Expected the record to have {} entries but they had {} entries!", argv[2], count_got);
|
||||
}
|
||||
|
||||
} catch (const memgraph::communication::bolt::ClientQueryException &e) {
|
||||
LOG_FATAL(
|
||||
"The query shoudn't have failed but it failed with an "
|
||||
"error message '{}', {}",
|
||||
e.what(), argv[0]);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
173
tests/integration/flag_check/runner.py
Normal file
173
tests/integration/flag_check/runner.py
Normal file
@ -0,0 +1,173 @@
|
||||
#!/usr/bin/python3 -u
|
||||
|
||||
# Copyright 2022 Memgraph Ltd.
|
||||
#
|
||||
# Use of this software is governed by the Business Source License
|
||||
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
# License, and you may not use this file except in compliance with the Business Source License.
|
||||
#
|
||||
# As of the Change Date specified in that file, in accordance with
|
||||
# the Business Source License, use of this software will be governed
|
||||
# by the Apache License, Version 2.0, included in the file
|
||||
# licenses/APL.txt.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
|
||||
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..", ".."))
|
||||
|
||||
|
||||
def wait_for_server(port: int, delay: float = 0.1) -> float:
|
||||
cmd = ["nc", "-z", "-w", "1", "127.0.0.1", str(port)]
|
||||
while subprocess.call(cmd) != 0:
|
||||
time.sleep(0.01)
|
||||
time.sleep(delay)
|
||||
|
||||
|
||||
def execute_tester(
|
||||
binary: str,
|
||||
queries: List[str],
|
||||
should_fail: bool = False,
|
||||
failure_message: str = "",
|
||||
username: str = "",
|
||||
password: str = "",
|
||||
check_failure: bool = True,
|
||||
) -> None:
|
||||
args = [binary, "--username", username, "--password", password]
|
||||
if should_fail:
|
||||
args.append("--should-fail")
|
||||
if failure_message:
|
||||
args.extend(["--failure-message", failure_message])
|
||||
if check_failure:
|
||||
args.append("--check-failure")
|
||||
args.extend(queries)
|
||||
subprocess.run(args).check_returncode()
|
||||
|
||||
|
||||
def execute_flag_check(binary: str, queries: List[str], expected: int, username: str = "", password: str = "") -> None:
|
||||
args = [binary, "--username", username, "--password", password]
|
||||
|
||||
args.extend(queries)
|
||||
args.append(str(expected))
|
||||
|
||||
subprocess.run(args).check_returncode()
|
||||
|
||||
|
||||
def start_memgraph(memgraph_args: List[any]) -> subprocess:
|
||||
memgraph = subprocess.Popen(list(map(str, memgraph_args)))
|
||||
time.sleep(0.1)
|
||||
assert memgraph.poll() is None, "Memgraph process died prematurely!"
|
||||
wait_for_server(7687)
|
||||
|
||||
return memgraph
|
||||
|
||||
|
||||
def execute_with_user(tester_binary: str, queries: List[str]) -> None:
|
||||
return execute_tester(
|
||||
tester_binary, queries, should_fail=False, check_failure=True, username="admin", password="admin"
|
||||
)
|
||||
|
||||
|
||||
def execute_without_user(
|
||||
tester_binary: str,
|
||||
queries: List[str],
|
||||
should_fail: bool = False,
|
||||
failure_message: str = "",
|
||||
check_failure: bool = True,
|
||||
) -> None:
|
||||
return execute_tester(tester_binary, queries, should_fail, failure_message, "", "", check_failure)
|
||||
|
||||
|
||||
def cleanup(memgraph: subprocess):
|
||||
if memgraph.poll() is None:
|
||||
memgraph.terminate()
|
||||
assert memgraph.wait() == 0, "Memgraph process didn't exit cleanly!"
|
||||
|
||||
|
||||
def test_without_any_files(tester_binary: str, memgraph_args: List[str]):
|
||||
memgraph = start_memgraph(memgraph_args)
|
||||
execute_without_user(tester_binary, ["MATCH (n) RETURN n"], False)
|
||||
cleanup(memgraph)
|
||||
|
||||
|
||||
def test_init_file(tester_binary: str, memgraph_args: List[str]):
|
||||
memgraph = start_memgraph(memgraph_args)
|
||||
execute_with_user(tester_binary, ["MATCH (n) RETURN n"])
|
||||
execute_without_user(tester_binary, ["MATCH (n) RETURN n"], True, "Handshake with the server failed!", True)
|
||||
cleanup(memgraph)
|
||||
|
||||
|
||||
def test_init_data_file(flag_checker_binary: str, memgraph_args: List[str]):
|
||||
memgraph = start_memgraph(memgraph_args)
|
||||
execute_flag_check(flag_checker_binary, ["MATCH (n) RETURN n"], 2, "user", "user")
|
||||
cleanup(memgraph)
|
||||
|
||||
|
||||
def test_init_and_init_data_file(flag_checker_binary: str, tester_binary: str, memgraph_args: List[str]):
|
||||
memgraph = start_memgraph(memgraph_args)
|
||||
execute_with_user(tester_binary, ["MATCH (n) RETURN n"])
|
||||
execute_without_user(tester_binary, ["MATCH (n) RETURN n"], True, "Handshake with the server failed!", True)
|
||||
execute_flag_check(flag_checker_binary, ["MATCH (n) RETURN n"], 2, "user", "user")
|
||||
cleanup(memgraph)
|
||||
|
||||
|
||||
def execute_test(memgraph_binary: str, tester_binary: str, flag_checker_binary: str) -> None:
|
||||
storage_directory = tempfile.TemporaryDirectory()
|
||||
memgraph_args = [memgraph_binary, "--data-directory", storage_directory.name]
|
||||
|
||||
# Start the memgraph binary
|
||||
with open(os.path.join(os.getcwd(), "dummy_init_file.cypherl"), "w") as temp_file:
|
||||
temp_file.write("CREATE USER admin IDENTIFIED BY 'admin';\n")
|
||||
temp_file.write("CREATE USER user IDENTIFIED BY 'user';\n")
|
||||
|
||||
with open(os.path.join(os.getcwd(), "dummy_init_data_file.cypherl"), "w") as temp_file:
|
||||
temp_file.write("CREATE (n:RANDOM) RETURN n;\n")
|
||||
temp_file.write("CREATE (n:RANDOM {name:'1'}) RETURN n;\n")
|
||||
|
||||
# Run the test with all combinations of permissions
|
||||
print("\033[1;36m~~ Starting env variable check test ~~\033[0m")
|
||||
test_without_any_files(tester_binary, memgraph_args)
|
||||
memgraph_args_with_init_file = memgraph_args + [
|
||||
"--init-file",
|
||||
os.path.join(os.getcwd(), "dummy_init_file.cypherl"),
|
||||
]
|
||||
test_init_file(tester_binary, memgraph_args_with_init_file)
|
||||
memgraph_args_with_init_data_file = memgraph_args + [
|
||||
"--init-data-file",
|
||||
os.path.join(os.getcwd(), "dummy_init_data_file.cypherl"),
|
||||
]
|
||||
|
||||
test_init_data_file(flag_checker_binary, memgraph_args_with_init_data_file)
|
||||
memgraph_args_with_init_file_and_init_data_file = memgraph_args + [
|
||||
"--init-file",
|
||||
os.path.join(os.getcwd(), "dummy_init_file.cypherl"),
|
||||
"--init-data-file",
|
||||
os.path.join(os.getcwd(), "dummy_init_data_file.cypherl"),
|
||||
]
|
||||
test_init_and_init_data_file(flag_checker_binary, tester_binary, memgraph_args_with_init_file_and_init_data_file)
|
||||
print("\033[1;36m~~ Ended env variable check test ~~\033[0m")
|
||||
|
||||
os.remove(os.path.join(os.getcwd(), "dummy_init_data_file.cypherl"))
|
||||
os.remove(os.path.join(os.getcwd(), "dummy_init_file.cypherl"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
memgraph_binary = os.path.join(PROJECT_DIR, "build", "memgraph")
|
||||
tester_binary = os.path.join(PROJECT_DIR, "build", "tests", "integration", "flag_check", "tester")
|
||||
flag_checker_binary = os.path.join(PROJECT_DIR, "build", "tests", "integration", "flag_check", "flag_check")
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--memgraph", default=memgraph_binary)
|
||||
parser.add_argument("--tester", default=tester_binary)
|
||||
parser.add_argument("--flag_checker", default=flag_checker_binary)
|
||||
args = parser.parse_args()
|
||||
|
||||
execute_test(args.memgraph, args.tester, args.flag_checker)
|
||||
|
||||
sys.exit(0)
|
94
tests/integration/flag_check/tester.cpp
Normal file
94
tests/integration/flag_check/tester.cpp
Normal file
@ -0,0 +1,94 @@
|
||||
// Copyright 2022 Memgraph Ltd.
|
||||
//
|
||||
// Use of this software is governed by the Business Source License
|
||||
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
|
||||
// License, and you may not use this file except in compliance with the Business Source License.
|
||||
//
|
||||
// As of the Change Date specified in that file, in accordance with
|
||||
// the Business Source License, use of this software will be governed
|
||||
// by the Apache License, Version 2.0, included in the file
|
||||
// licenses/APL.txt.
|
||||
|
||||
#include <gflags/gflags.h>
|
||||
|
||||
#include "communication/bolt/client.hpp"
|
||||
#include "io/network/endpoint.hpp"
|
||||
#include "io/network/utils.hpp"
|
||||
|
||||
DEFINE_string(address, "127.0.0.1", "Server address");
|
||||
DEFINE_int32(port, 7687, "Server port");
|
||||
DEFINE_string(username, "", "Username for the database");
|
||||
DEFINE_string(password, "", "Password for the database");
|
||||
DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server.");
|
||||
|
||||
DEFINE_bool(check_failure, false, "Set to true to enable failure checking.");
|
||||
DEFINE_bool(should_fail, false, "Set to true to expect a failure.");
|
||||
DEFINE_string(failure_message, "", "Set to the expected failure message.");
|
||||
|
||||
int ProcessException(const std::string &exception_message) {
|
||||
if (FLAGS_should_fail) {
|
||||
if (!FLAGS_failure_message.empty() && exception_message != FLAGS_failure_message) {
|
||||
LOG_FATAL(
|
||||
"The query should have failed with an error message of '{}'' but "
|
||||
"instead it failed with '{}'",
|
||||
FLAGS_failure_message, exception_message);
|
||||
}
|
||||
return 0;
|
||||
} else {
|
||||
LOG_FATAL(
|
||||
"The query shoudn't have failed but it failed with an "
|
||||
"error message '{}'",
|
||||
exception_message);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
/**
|
||||
* Executes queries passed as positional arguments and verifies whether they
|
||||
* succeeded, failed, failed with a specific error message or executed without a
|
||||
* specific error occurring.
|
||||
*/
|
||||
int main(int argc, char **argv) {
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||
|
||||
memgraph::communication::SSLInit sslInit;
|
||||
|
||||
memgraph::io::network::Endpoint endpoint(memgraph::io::network::ResolveHostname(FLAGS_address), FLAGS_port);
|
||||
|
||||
memgraph::communication::ClientContext context(FLAGS_use_ssl);
|
||||
memgraph::communication::bolt::Client client(context);
|
||||
|
||||
try {
|
||||
client.Connect(endpoint, FLAGS_username, FLAGS_password);
|
||||
} catch (const memgraph::utils::BasicException &e) {
|
||||
return ProcessException(e.what());
|
||||
}
|
||||
|
||||
for (int i = 1; i < argc; ++i) {
|
||||
std::string query(argv[i]);
|
||||
try {
|
||||
client.Execute(query, {});
|
||||
} catch (const memgraph::communication::bolt::ClientQueryException &e) {
|
||||
if (!FLAGS_check_failure) {
|
||||
if (!FLAGS_failure_message.empty() && e.what() == FLAGS_failure_message) {
|
||||
LOG_FATAL(
|
||||
"The query should have succeeded or failed with an error "
|
||||
"message that isn't equal to '{}' but it failed with that error "
|
||||
"message",
|
||||
FLAGS_failure_message);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (!ProcessException(e.what())) {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
if (!FLAGS_check_failure) continue;
|
||||
if (FLAGS_should_fail) {
|
||||
LOG_FATAL(
|
||||
"The query should have failed but instead it executed "
|
||||
"successfully!");
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
Loading…
Reference in New Issue
Block a user