Fix older version of SSL (#24)

* Add manual locking callback for SSL

Co-authored-by: Antonio Andelic <antonio.andelic@memgraph.io>
This commit is contained in:
antonio2368 2020-10-20 12:55:13 +02:00 committed by GitHub
parent 291158160d
commit 0a7d4278b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
32 changed files with 114 additions and 50 deletions

View File

@ -17,6 +17,9 @@
* Fixed Cypher `ID` function `Null` handling. When the `ID` function receives
`Null`, it will also return `Null`.
* Fixed bug that caused random crashes in SSL communication on platforms
that use older versions of OpenSSL (< 1.1) by adding proper multi-threading
handling.
## v1.1.0

View File

@ -17,8 +17,8 @@ namespace communication {
* It uses blocking sockets and provides an API that can be used to receive/send
* data over the network connection.
*
* NOTE: If you use this client you **must** call `communication::Init()` from
* the `main` function before using the client!
* NOTE: If you use this client you **must** create `communication::SSLInit`
* from the `main` function before using the client!
*/
class Client final {
public:

View File

@ -1,14 +1,57 @@
#include "init.hpp"
#include <glog/logging.h>
#include <openssl/bio.h>
#include <openssl/err.h>
#include <openssl/ssl.h>
#include <thread>
#include "utils/signals.hpp"
#include "utils/spin_lock.hpp"
namespace communication {
void Init() {
namespace {
// OpenSSL before 1.1 did not have a out-of-the-box multithreading support
// You need to manually define locks, locking callbacks and id function.
// https://stackoverflow.com/a/42856544
// https://wiki.openssl.org/index.php/Library_Initialization#libssl_Initialization
// https://www.openssl.org/docs/man1.0.2/man3/CRYPTO_num_locks.html
#if OPENSSL_VERSION_NUMBER < 0x10100000L
std::vector<utils::SpinLock> crypto_locks;
void LockingFunction(int mode, int n, const char *file, int line) {
if (mode & CRYPTO_LOCK) {
crypto_locks[n].lock();
} else {
crypto_locks[n].unlock();
}
}
unsigned long IdFunction() {
return (unsigned long)std::hash<std::thread::id>()(
std::this_thread::get_id());
}
void SetupThreading() {
mutex.resize(CRYPTO_num_locks());
CRYPTO_set_id_callback(IdFunction);
CRYPTO_set_locking_callback(LockingFunction);
}
void Cleanup() {
CRYPTO_set_id_callback(nullptr);
CRYPTO_set_locking_callback(nullptr);
mutex.clear();
}
#else
void SetupThreading() {}
void Cleanup() {}
#endif
} // namespace
SSLInit::SSLInit() {
// Initialize the OpenSSL library.
SSL_library_init();
OpenSSL_add_ssl_algorithms();
@ -17,5 +60,9 @@ void Init() {
// Ignore SIGPIPE.
CHECK(utils::SignalIgnore(utils::Signal::Pipe)) << "Couldn't ignore SIGPIPE!";
SetupThreading();
}
SSLInit::~SSLInit() { Cleanup(); }
} // namespace communication

View File

@ -3,14 +3,26 @@
namespace communication {
/**
* Call this function in each `main` file that uses the Communication stack. It
* Create this object in each `main` file that uses the Communication stack. It
* is used to initialize all libraries (primarily OpenSSL) and to fix some
* issues also related to OpenSSL (handling of SIGPIPE).
*
* We define a struct to take advantage of RAII so that the proper cleanup
* is called after we are finished using the SSL connection.
*
* Description of OpenSSL init can be seen here:
* https://wiki.openssl.org/index.php/Library_Initialization
*
* NOTE: This function must be called **exactly** once.
* NOTE: This object must be created **exactly** once.
*/
void Init();
struct SSLInit {
SSLInit();
SSLInit(const SSLInit &) = delete;
SSLInit(SSLInit &&) = delete;
SSLInit &operator=(const SSLInit &) = delete;
SSLInit &operator=(SSLInit &&) = delete;
~SSLInit();
};
} // namespace communication

View File

@ -28,8 +28,8 @@ namespace communication {
* Current Server achitecture:
* incoming connection -> server -> listener -> session
*
* NOTE: If you use this server you **must** call `communication::Init()` from
* the `main` function before using the server!
* NOTE: If you use this server you **must** create `communication::SSLInit`
* from the `main` function before using the server!
*
* @tparam TSession the server can handle different Sessions, each session
* represents a different protocol so the same network infrastructure

View File

@ -883,7 +883,7 @@ int main(int argc, char **argv) {
}
// Initialize the communication library.
communication::Init();
communication::SSLInit sslInit;
// Initialize the requests library.
requests::Init();

View File

@ -69,7 +69,7 @@ int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
communication::Init();
communication::SSLInit sslInit;
io::network::Endpoint endpoint(io::network::ResolveHostname(FLAGS_address),
FLAGS_port);

View File

@ -19,7 +19,7 @@ int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
communication::Init();
communication::SSLInit sslInit;
io::network::Endpoint endpoint(io::network::ResolveHostname(FLAGS_address),
FLAGS_port);

View File

@ -24,7 +24,7 @@ int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
communication::Init();
communication::SSLInit sslInit;
io::network::Endpoint endpoint(io::network::ResolveHostname(FLAGS_address),
FLAGS_port);

View File

@ -25,7 +25,7 @@ int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
communication::Init();
communication::SSLInit sslInit;
try {
std::vector<io::network::Endpoint> endpoints(FLAGS_cluster_size);

View File

@ -22,8 +22,9 @@ DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server.");
DEFINE_string(step, "",
"The step to execute (available: create, check, add_node, drop");
DEFINE_int32(property_value, 0, "Value of the property when creating a node.");
DEFINE_int32(expected_status, 0,
"Expected query execution status when creating a node, 0 is success");
DEFINE_int32(
expected_status, 0,
"Expected query execution status when creating a node, 0 is success");
DEFINE_int32(expected_result, 0, "Expected query result");
using namespace std::chrono_literals;
@ -32,7 +33,7 @@ int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
communication::Init();
communication::SSLInit sslInit;
try {
std::vector<io::network::Endpoint> endpoints;
@ -70,8 +71,8 @@ int main(int argc, char **argv) {
if (result.records.size() != FLAGS_expected_result) {
LOG(WARNING) << "Unexpected number of nodes: "
<< "expected " << FLAGS_expected_result
<< ", got " << result.records.size();
<< "expected " << FLAGS_expected_result << ", got "
<< result.records.size();
return 2;
}
return 0;
@ -94,6 +95,5 @@ int main(int argc, char **argv) {
LOG(WARNING) << "Error while executing query.";
}
return 1;
}

View File

@ -28,7 +28,7 @@ int main(int argc, char **argv) {
const std::string index = ":Node(id)";
communication::Init();
communication::SSLInit sslInit;
try {
std::vector<io::network::Endpoint> endpoints(FLAGS_cluster_size);
for (int i = 0; i < FLAGS_cluster_size; ++i)

View File

@ -27,7 +27,7 @@ int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
communication::Init();
communication::SSLInit sslInit;
try {
std::vector<io::network::Endpoint> endpoints(FLAGS_cluster_size);
for (int i = 0; i < FLAGS_cluster_size; ++i)

View File

@ -24,7 +24,7 @@ int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
communication::Init();
communication::SSLInit sslInit;
try {
std::vector<io::network::Endpoint> endpoints(FLAGS_cluster_size);
for (int i = 0; i < FLAGS_cluster_size; ++i)

View File

@ -27,7 +27,7 @@ int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
communication::Init();
communication::SSLInit sslInit;
try {
std::vector<io::network::Endpoint> endpoints;
@ -52,8 +52,8 @@ int main(int argc, char **argv) {
if (result.records.size() != FLAGS_expected_results) {
LOG(WARNING) << "Unexpected number of nodes: "
<< "expected " << FLAGS_expected_results
<< ", got " << result.records.size();
<< "expected " << FLAGS_expected_results << ", got "
<< result.records.size();
return 2;
}
return 0;

View File

@ -25,7 +25,7 @@ int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
communication::Init();
communication::SSLInit sslInit;
io::network::Endpoint endpoint(io::network::ResolveHostname(FLAGS_address),
FLAGS_port);

View File

@ -20,7 +20,7 @@ int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
communication::Init();
communication::SSLInit sslInit;
io::network::Endpoint endpoint(io::network::ResolveHostname(FLAGS_address),
FLAGS_port);

View File

@ -46,7 +46,7 @@ int main(int argc, char **argv) {
google::InitGoogleLogging(argv[0]);
// Initialize the communication stack.
communication::Init();
communication::SSLInit sslInit;
// Initialize the server.
EchoData echo_data;

View File

@ -478,7 +478,7 @@ int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
communication::Init();
communication::SSLInit sslInit;
return RUN_ALL_TESTS();
}

View File

@ -104,7 +104,7 @@ int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
communication::Init();
communication::SSLInit sslInit;
Endpoint endpoint(FLAGS_address, FLAGS_port);
ClientContext context(FLAGS_use_ssl);

View File

@ -319,7 +319,7 @@ int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
communication::Init();
communication::SSLInit sslInit;
Endpoint endpoint(FLAGS_address, FLAGS_port);
ClientContext context(FLAGS_use_ssl);
@ -346,10 +346,10 @@ int main(int argc, char **argv) {
CHECK(FLAGS_num_workers >= 2)
<< "There should be at least 2 client workers (analytic and cleanup)";
CHECK(num_pos == config["num_workers"].get<int>() *
config["pos_per_worker"].get<int>())
config["pos_per_worker"].get<int>())
<< "Wrong number of POS per worker";
CHECK(num_cards == config["num_workers"].get<int>() *
config["cards_per_worker"].get<int>())
config["cards_per_worker"].get<int>())
<< "Wrong number of cards per worker";
for (int i = 0; i < FLAGS_num_workers - 2; ++i) {
clients.emplace_back(std::make_unique<CardFraudClient>(i, config));

View File

@ -272,7 +272,7 @@ int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
communication::Init();
communication::SSLInit sslInit;
nlohmann::json config;
std::cin >> config;

View File

@ -6,6 +6,7 @@
#include <gflags/gflags.h>
#include <glog/logging.h>
#include "communication/init.hpp"
#include "utils/algorithm.hpp"
#include "utils/spin_lock.hpp"
#include "utils/string.hpp"
@ -101,7 +102,7 @@ int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
communication::Init();
communication::SSLInit sslInit;
std::ifstream ifile;
std::istream *istream{&std::cin};

View File

@ -18,7 +18,7 @@ int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
communication::Init();
communication::SSLInit sslInit;
// TODO: handle endpoint exception
io::network::Endpoint endpoint(io::network::ResolveHostname(FLAGS_address),

View File

@ -28,7 +28,7 @@ int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
communication::Init();
communication::SSLInit sslInit;
std::vector<io::network::Endpoint> endpoints;
endpoints.reserve(FLAGS_cluster_size);

View File

@ -182,7 +182,7 @@ int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
communication::Init();
communication::SSLInit sslInit;
communication::ServerContext context;
std::string service_name = "Bolt";

View File

@ -38,7 +38,7 @@ int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
communication::Init();
communication::SSLInit sslInit;
io::network::Endpoint endpoint(FLAGS_address, FLAGS_port);

View File

@ -56,7 +56,7 @@ int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
communication::Init();
communication::SSLInit sslInit;
// Initialize the server.
EchoData echo_data;

View File

@ -15,6 +15,7 @@
#include "communication/bolt/client.hpp"
#include "communication/bolt/v1/value.hpp"
#include "communication/init.hpp"
#include "utils/exceptions.hpp"
#include "utils/string.hpp"
#include "utils/timer.hpp"
@ -229,7 +230,7 @@ int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
communication::Init();
communication::SSLInit sslInit;
std::ifstream ifile;
std::istream *istream{&std::cin};

View File

@ -109,7 +109,7 @@ class GraphSession {
bool Bernoulli(double p) { return GetRandom() < p; }
template<typename T>
template <typename T>
T RandomElement(const std::set<T> &data) {
uint32_t pos = std::floor(GetRandom() * data.size());
auto it = data.begin();
@ -380,7 +380,7 @@ int main(int argc, char **argv) {
CHECK(FLAGS_vertex_count > 0) << "Vertex count must be greater than 0!";
CHECK(FLAGS_edge_count > 0) << "Edge count must be greater than 0!";
communication::Init();
communication::SSLInit sslInit;
LOG(INFO) << "Starting Memgraph HA normal operation long running test";
@ -397,7 +397,8 @@ int main(int argc, char **argv) {
// cleanup and create indexes
client.Execute("MATCH (n) DETACH DELETE n", {});
for (int i = 0; i < FLAGS_worker_count; ++i) {
client.Execute(fmt::format("CREATE INDEX ON :indexed_label{}(id)", i), {});
client.Execute(fmt::format("CREATE INDEX ON :indexed_label{}(id)", i),
{});
}
} catch (const communication::bolt::ClientFatalException &e) {
LOG(WARNING) << "Unable to find cluster leader";
@ -415,8 +416,7 @@ int main(int argc, char **argv) {
// sessions
std::vector<GraphSession> sessions;
sessions.reserve(FLAGS_worker_count);
for (int i = 0; i < FLAGS_worker_count; ++i)
sessions.emplace_back(i);
for (int i = 0; i < FLAGS_worker_count; ++i) sessions.emplace_back(i);
// workers
std::vector<std::thread> threads;
@ -424,8 +424,7 @@ int main(int argc, char **argv) {
for (int i = 0; i < FLAGS_worker_count; ++i)
threads.emplace_back([&, i]() { sessions[i].Run(); });
for (int i = 0; i < FLAGS_worker_count; ++i)
threads[i].join();
for (int i = 0; i < FLAGS_worker_count; ++i) threads[i].join();
if (!FLAGS_stats_file.empty()) {
uint64_t executed = 0;

View File

@ -415,7 +415,7 @@ int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
communication::Init();
communication::SSLInit sslInit;
CHECK(FLAGS_vertex_count > 0) << "Vertex count must be greater than 0!";
CHECK(FLAGS_edge_count > 0) << "Edge count must be greater than 0!";

View File

@ -12,6 +12,7 @@
#include <unordered_set>
#include "communication/bolt/client.hpp"
#include "communication/init.hpp"
#include "io/network/endpoint.hpp"
#include "io/network/utils.hpp"
#include "utils/algorithm.hpp"
@ -550,7 +551,7 @@ int main(int argc, char **argv) {
FLAGS_min_log_level = google::ERROR;
google::InitGoogleLogging(argv[0]);
communication::Init();
communication::SSLInit sslInit;
#ifdef HAS_READLINE
using_history();