diff --git a/tests/manual/CMakeLists.txt b/tests/manual/CMakeLists.txt
index 92c6a021b..015123049 100644
--- a/tests/manual/CMakeLists.txt
+++ b/tests/manual/CMakeLists.txt
@@ -55,6 +55,10 @@ target_link_libraries(${test_prefix}graph_500_generate_snapshot mg-distributed k
 add_manual_test(ha_client.cpp)
 target_link_libraries(${test_prefix}ha_client mg-utils mg-communication)
 
+add_manual_test(ha_proxy.cpp)
+target_include_directories(${test_prefix}ha_proxy PRIVATE ${CMAKE_BINARY_DIR}/src)
+target_link_libraries(${test_prefix}ha_proxy mg-utils mg-communication)
+
 add_manual_test(kvstore_console.cpp)
 target_link_libraries(${test_prefix}kvstore_console kvstore_lib gflags glog)
 
diff --git a/tests/manual/ha_proxy.cpp b/tests/manual/ha_proxy.cpp
new file mode 100644
index 000000000..61e3fb7e0
--- /dev/null
+++ b/tests/manual/ha_proxy.cpp
@@ -0,0 +1,239 @@
+#include <algorithm>
+#include <chrono>
+#include <csignal>
+#include <cstddef>
+#include <cstdint>
+#include <functional>
+#include <limits>
+#include <map>
+#include <memory>
+#include <stdexcept>
+#include <string>
+#include <thread>
+#include <utility>
+#include <vector>
+
+#include <fmt/format.h>
+#include <gflags/gflags.h>
+#include <glog/logging.h>
+
+#include "communication/bolt/ha_client.hpp"
+#include "communication/bolt/v1/exceptions.hpp"
+#include "communication/bolt/v1/session.hpp"
+#include "communication/server.hpp"
+#include "io/network/endpoint.hpp"
+#include "io/network/utils.hpp"
+#include "utils/flag_validation.hpp"
+#include "utils/signals.hpp"
+#include "utils/string.hpp"
+#include "version.hpp"
+
+DEFINE_string(address, "127.0.0.1", "Proxy server listen address.");
+// Narrowed to uint16_t in main, so reject out-of-range values up front.
+DEFINE_VALIDATED_int32(port, 7687, "Proxy server listen port.",
+                       FLAG_IN_RANGE(0, std::numeric_limits<uint16_t>::max()));
+DEFINE_string(cert_file, "", "Proxy server SSL certificate file.");
+DEFINE_string(key_file, "", "Proxy server SSL key file.");
+DEFINE_VALIDATED_int32(num_workers,
+                       std::max(std::thread::hardware_concurrency(), 1U),
+                       "Proxy server number of workers (Bolt).",
+                       FLAG_IN_RANGE(1, INT32_MAX));
+DEFINE_VALIDATED_int32(session_inactivity_timeout, 1800,
+                       "Proxy server time in seconds after which inactive "
+                       "sessions will be closed.",
+                       FLAG_IN_RANGE(1, INT32_MAX));
+
+DEFINE_string(endpoints, "",
+              "Cluster server endpoints (host:port, separated by comma).");
+DEFINE_bool(use_ssl, true,
+            "Set to true to connect with SSL to the cluster servers.");
+
+DEFINE_int32(num_retries, 3,
+             "Number of retries for each operation (execute/connect).");
+DEFINE_int32(retry_delay_ms, 1000, "Delay before retrying (in ms).");
+
+// Global data state that is used by the BoltSession.
+struct SessionData {
+  // Cluster server endpoints handed to the HAClient of every session.
+  std::vector<io::network::Endpoint> endpoints;
+};
+
+// Bolt session that forwards each query to the cluster through an HAClient
+// and buffers the result until it is pulled.
+class BoltSession final
+    : public communication::bolt::Session<communication::InputStream,
+                                          communication::OutputStream> {
+ public:
+  BoltSession(SessionData *data, const io::network::Endpoint &endpoint,
+              communication::InputStream *input_stream,
+              communication::OutputStream *output_stream)
+      : communication::bolt::Session<communication::InputStream,
+                                     communication::OutputStream>(
+            input_stream, output_stream),
+        session_data_(data),
+        endpoint_(endpoint),
+        context_(FLAGS_use_ssl) {}
+
+  using communication::bolt::Session<communication::InputStream,
+                                     communication::OutputStream>::TEncoder;
+
+  // Executes `query` through the cluster client, keeps the returned records
+  // and metadata for the following `PullAll` and returns the result fields.
+  std::vector<std::string> Interpret(
+      const std::string &query,
+      const std::map<std::string, communication::bolt::Value> &params)
+      override {
+    records_ = {};
+    metadata_ = {};
+    try {
+      auto ret = client_->Execute(query, params);
+      records_ = std::move(ret.records);
+      metadata_ = std::move(ret.metadata);
+      // `ret` is a local that is about to be destroyed, so move the fields
+      // out instead of copying them.
+      return std::move(ret.fields);
+    } catch (const communication::bolt::ClientQueryException &e) {
+      // Wrap query exceptions in a client error to indicate to the client that
+      // it should fix the query and try again.
+      throw communication::bolt::ClientError(e.what());
+    } catch (const communication::bolt::ClientFatalException &e) {
+      // Wrap fatal exceptions in a verbose error to indicate to the client that
+      // something is wrong with the database.
+      throw communication::bolt::VerboseError(
+          communication::bolt::VerboseError::Classification::DATABASE_ERROR,
+          "HighAvailability", "Error", e.what());
+    }
+  }
+
+  // Sends every buffered record to `encoder` and returns the buffered
+  // metadata.
+  std::map<std::string, communication::bolt::Value> PullAll(
+      TEncoder *encoder) override {
+    for (const auto &record : records_) {
+      encoder->MessageRecord(record);
+    }
+    return metadata_;
+  }
+
+  void Abort() override {
+    // Called only for cleanup.
+    records_.clear();
+    metadata_.clear();
+  }
+
+  // Creates the cluster client with the supplied credentials and always
+  // returns true. NOTE(review): presumably bad credentials surface later as
+  // exceptions from the HAClient -- confirm.
+  bool Authenticate(const std::string &username,
+                    const std::string &password) override {
+    client_ = std::make_unique<communication::bolt::HAClient>(
+        session_data_->endpoints, &context_, username, password,
+        FLAGS_num_retries, std::chrono::milliseconds(FLAGS_retry_delay_ms),
+        fmt::format("memgraph_ha_proxy/{}", version_string));
+    return true;
+  }
+
+ private:
+  SessionData *session_data_;
+  io::network::Endpoint endpoint_;
+
+  communication::ClientContext context_;
+  std::unique_ptr<communication::bolt::HAClient> client_;
+
+  std::vector<std::vector<communication::bolt::Value>> records_;
+  std::map<std::string, communication::bolt::Value> metadata_;
+};
+
+// Needed to correctly handle proxy destruction from a signal handler.
+// Without having some sort of a flag, it is possible that a signal is handled
+// when we are exiting main and that would cause a crash.
+volatile sig_atomic_t is_shutting_down = 0;
+
+// Registers SIGTERM and SIGINT handlers that invoke `shutdown_fun` at most
+// once, plus a SIGUSR1 handler that closes the INFO log destination.
+void InitSignalHandlers(const std::function<void()> &shutdown_fun) {
+  // Prevent handling shutdown inside a shutdown. For example, SIGINT handler
+  // being interrupted by SIGTERM before is_shutting_down is set, thus causing
+  // double shutdown.
+  sigset_t block_shutdown_signals;
+  sigemptyset(&block_shutdown_signals);
+  sigaddset(&block_shutdown_signals, SIGTERM);
+  sigaddset(&block_shutdown_signals, SIGINT);
+
+  // Wrap the shutdown function in a safe way to prevent recursive shutdown.
+  auto shutdown = [shutdown_fun]() {
+    if (is_shutting_down) return;
+    is_shutting_down = 1;
+    shutdown_fun();
+  };
+
+  CHECK(utils::SignalHandler::RegisterHandler(utils::Signal::Terminate,
+                                              shutdown, block_shutdown_signals))
+      << "Unable to register SIGTERM handler!";
+  CHECK(utils::SignalHandler::RegisterHandler(utils::Signal::Interupt, shutdown,
+                                              block_shutdown_signals))
+      << "Unable to register SIGINT handler!";
+
+  // Setup SIGUSR1 to be used for reopening log files, when e.g. logrotate
+  // rotates our logs.
+  CHECK(utils::SignalHandler::RegisterHandler(utils::Signal::User1, []() {
+    google::CloseLogDestination(google::INFO);
+  })) << "Unable to register SIGUSR1 handler!";
+}
+
+// Parses --endpoints into a list of resolved endpoints. The process is
+// aborted via CHECK when an entry is not `host:port` with a port in 0-65535.
+std::vector<io::network::Endpoint> GetEndpoints() {
+  std::vector<io::network::Endpoint> ret;
+  for (const auto &endpoint : utils::Split(FLAGS_endpoints, ",")) {
+    auto split = utils::Split(utils::Trim(endpoint), ":");
+    CHECK(split.size() == 2) << "Invalid endpoint!";
+
+    // Parse the port strictly: a bare std::stoi throws on garbage and the
+    // narrowing cast would silently wrap values above 65535.
+    const std::string port_str(utils::Trim(split[1]));
+    long port = -1;
+    try {
+      std::size_t parsed = 0;
+      port = std::stol(port_str, &parsed);
+      if (parsed != port_str.size()) port = -1;
+    } catch (const std::logic_error &) {
+      // std::invalid_argument / std::out_of_range: `port` stays -1 and the
+      // CHECK below reports it.
+    }
+    CHECK(port >= 0 && port <= std::numeric_limits<uint16_t>::max())
+        << "Invalid endpoint port!";
+
+    ret.emplace_back(io::network::ResolveHostname(utils::Trim(split[0])),
+                     static_cast<uint16_t>(port));
+  }
+  return ret;
+}
+
+int main(int argc, char **argv) {
+  gflags::ParseCommandLineFlags(&argc, &argv, true);
+  google::InitGoogleLogging(argv[0]);
+
+  communication::Init();
+
+  communication::ServerContext context;
+  std::string service_name = "Bolt";
+  if (!FLAGS_key_file.empty() && !FLAGS_cert_file.empty()) {
+    context = communication::ServerContext(FLAGS_key_file, FLAGS_cert_file);
+    service_name = "BoltS";
+  }
+
+  SessionData session_data{GetEndpoints()};
+  communication::Server<BoltSession, SessionData> server(
+      {FLAGS_address, static_cast<uint16_t>(FLAGS_port)}, &session_data,
+      &context, FLAGS_session_inactivity_timeout, service_name,
+      FLAGS_num_workers);
+
+  // Handler for regular termination signals
+  InitSignalHandlers([&server] { server.Shutdown(); });
+
+  CHECK(server.Start()) << "Couldn't start the Bolt server!";
+  server.AwaitShutdown();
+
+  return 0;
+}