From 1d448d40ca19e2eb5a18d27e6fa055b0bd0db520 Mon Sep 17 00:00:00 2001 From: Matej Ferencevic Date: Wed, 20 Jun 2018 17:44:47 +0200 Subject: [PATCH] Implement SSL support for servers and clients Summary: This diff implements OpenSSL support in the network stack. Currently SSL support is only enabled for Bolt connections, support for RPC connections will be added in another diff. Reviewers: buda, teon.banek Reviewed By: buda Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D1328 --- CHANGELOG.md | 1 + CMakeLists.txt | 7 + config/community.conf | 6 + docs/user_technical/drivers.md | 41 ++-- docs/user_technical/quick-start.md | 4 +- init | 1 + release/debian/postinst | 10 + release/rpm/memgraph.spec.in | 10 + src/CMakeLists.txt | 6 + src/communication/bolt/client.hpp | 4 +- src/communication/buffer.hpp | 17 +- src/communication/client.cpp | 232 ++++++++++++++++++ src/communication/client.hpp | 117 +++++---- src/communication/context.cpp | 80 ++++++ src/communication/context.hpp | 63 +++++ src/communication/helpers.cpp | 13 + src/communication/helpers.hpp | 12 + src/communication/init.cpp | 21 ++ src/communication/init.hpp | 16 ++ src/communication/listener.hpp | 44 ++-- src/communication/rpc/client.cpp | 2 +- src/communication/rpc/client.hpp | 2 + src/communication/rpc/server.cpp | 2 +- src/communication/rpc/server.hpp | 2 + src/communication/server.hpp | 27 +- src/communication/session.hpp | 210 ++++++++++++++-- src/io/network/socket.cpp | 37 ++- src/io/network/socket.hpp | 36 ++- src/memgraph_bolt.cpp | 30 ++- tests/concurrent/network_common.hpp | 5 +- tests/concurrent/network_read_hang.cpp | 9 +- tests/concurrent/network_server.cpp | 3 +- tests/concurrent/network_session_leak.cpp | 3 +- tests/integration/CMakeLists.txt | 3 + tests/integration/apollo_runs.yaml | 8 + tests/integration/ssl/CMakeLists.txt | 6 + tests/integration/ssl/runner.sh | 78 ++++++ tests/integration/ssl/tester.cpp | 76 ++++++ tests/macro_benchmark/clients/bolt_client.hpp | 15 +- .../clients/card_fraud_client.cpp | 5 +- tests/macro_benchmark/clients/common.hpp | 1 + .../clients/long_running_common.hpp | 4 +- .../macro_benchmark/clients/pokec_client.cpp | 5 +- .../macro_benchmark/clients/query_client.cpp | 6 +- tests/manual/CMakeLists.txt | 6 + tests/manual/bolt_client.cpp | 7 +- tests/manual/ssl_client.cpp | 65 +++++ tests/manual/ssl_server.cpp | 73 ++++++ tests/stress/.gitignore | 1 + tests/stress/apollo_runs.yaml | 4 + tests/stress/common.py | 13 +- tests/stress/continuous_integration | 21 ++ tests/stress/long_running.cpp | 10 +- tests/unit/network_timeouts.cpp | 3 +- tests/unit/socket.cpp | 94 +++++++ 55 files changed, 1400 insertions(+), 177 deletions(-) create mode 100644 src/communication/client.cpp create mode 100644 src/communication/context.cpp create mode 100644 src/communication/context.hpp create mode 100644 src/communication/helpers.cpp create mode 100644 src/communication/helpers.hpp create mode 100644 src/communication/init.cpp create mode 100644 src/communication/init.hpp create mode 100644 tests/integration/ssl/CMakeLists.txt create mode 100755 tests/integration/ssl/runner.sh create mode 100644 tests/integration/ssl/tester.cpp create mode 100644 tests/manual/ssl_client.cpp create mode 100644 tests/manual/ssl_server.cpp create mode 100644 tests/unit/socket.cpp diff --git a/CHANGELOG.md b/CHANGELOG.md index c3ee456fe..dfffa666b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ * Static vertices/edges id generators exposed through the Id Cypher function. * Properties on disk added. * Telemetry added. +* SSL support added. * Add `toString` function to openCypher ### Bug Fixes and Other Changes diff --git a/CMakeLists.txt b/CMakeLists.txt index e09b9b919..ecf5efb42 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -140,6 +140,9 @@ endif() set(Boost_USE_STATIC_LIBS ON) find_package(Boost 1.62 REQUIRED COMPONENTS iostreams serialization) +# OpenSSL +find_package(OpenSSL REQUIRED) + set(libs_dir ${CMAKE_SOURCE_DIR}/libs) add_subdirectory(libs EXCLUDE_FROM_ALL) @@ -320,6 +323,8 @@ set(CPACK_DEBIAN_PACKAGE_DESCRIPTION "${CPACK_PACKAGE_DESCRIPTION_SUMMARY} Contains Memgraph, the graph database. It aims to deliver developers the speed, simplicity and scale required to build the next generation of applications driver by real-time connected data.") +# Add `openssl` package to dependencies list. Used to generate SSL certificates. +set(CPACK_DEBIAN_memgraph_PACKAGE_DEPENDS "openssl (>= 1.1.0)") # RPM specific set(CPACK_RPM_PACKAGE_URL https://memgraph.com) @@ -335,6 +340,8 @@ set(CPACK_RPM_USER_BINARY_SPECFILE "${CMAKE_SOURCE_DIR}/release/rpm/memgraph.spe set(CPACK_RPM_PACKAGE_DESCRIPTION "Contains Memgraph, the graph database. It aims to deliver developers the speed, simplicity and scale required to build the next generation of applications driver by real-time connected data.") +# Add `openssl` package to dependencies list. Used to generate SSL certificates. +set(CPACK_RPM_memgraph_PACKAGE_REQUIRES "openssl >= 1.0.0") # All variables must be set before including. include(CPack) diff --git a/config/community.conf b/config/community.conf index 0901d794e..61e17eb00 100644 --- a/config/community.conf +++ b/config/community.conf @@ -16,6 +16,12 @@ # Port the server should listen on. --port=7687 +# Path to a SSL certificate file that should be used. +--cert-file=/etc/memgraph/ssl/cert.pem + +# Path to a SSL key file that should be used. +--key-file=/etc/memgraph/ssl/key.pem + # Number of workers used by the Memgraph server. By default, this will be the # number of processing units available on the machine. # --num-workers=8 diff --git a/docs/user_technical/drivers.md b/docs/user_technical/drivers.md index 0a6295ebc..09fd3958e 100644 --- a/docs/user_technical/drivers.md +++ b/docs/user_technical/drivers.md @@ -13,11 +13,9 @@ from neo4j.v1 import GraphDatabase, basic_auth # Initialize and configure the driver. # * provide the correct URL where Memgraph is reachable; -# * use an empty user name and password, and -# * disable encryption (not supported). +# * use an empty user name and password. driver = GraphDatabase.driver("bolt://localhost:7687", - auth=basic_auth("", ""), - encrypted=False) + auth=basic_auth("", "")) # Start a session in which queries are executed. session = driver.session() @@ -51,9 +49,7 @@ The details about Java driver can be found [on GitHub](https://github.com/neo4j/neo4j-java-driver). The example below is equivalent to Python example. Major difference is that -`Config` object has to be created before the driver construction. Encryption -has to be disabled by calling `withoutEncryption` method against the `Config` -builder. +`Config` object has to be created before the driver construction. ```java import org.neo4j.driver.v1.*; @@ -64,7 +60,7 @@ import java.util.*; public class JavaQuickStart { public static void main(String[] args) { // Initialize driver. - Config config = Config.build().withoutEncryption().toConfig(); + Config config = Config.build().toConfig(); Driver driver = GraphDatabase.driver("bolt://localhost:7687", AuthTokens.basic("",""), config); @@ -93,9 +89,7 @@ public class JavaQuickStart { The details about Javascript driver can be found [on GitHub](https://github.com/neo4j/neo4j-javascript-driver). -The Javascript example below is equivalent to Python and Java examples. SSL -can be disabled by passing `{encrypted: 'ENCRYPTION_OFF'}` during the driver -construction. +The Javascript example below is equivalent to Python and Java examples. Here is an example related to `Node.js`. Memgraph doesn't have integrated support for `WebSocket` which is required during the execution in any web @@ -109,8 +103,7 @@ proxy port. ```javascript var neo4j = require('neo4j-driver').v1; var driver = neo4j.driver("bolt://localhost:7687", - neo4j.auth.basic("neo4j", "1234"), - {encrypted: 'ENCRYPTION_OFF'}); + neo4j.auth.basic("neo4j", "1234")); var session = driver.session(); function die() { @@ -146,8 +139,7 @@ run_query("MATCH (n) DETACH DELETE n", function (result) { The C# driver is hosted [on GitHub](https://github.com/neo4j/neo4j-dotnet-driver). The example below -performs the same work as all of the previous examples. Encryption is disabled -by setting `EncryptionLevel.NONE` on the `Config`. +performs the same work as all of the previous examples. ```csh using System; @@ -158,7 +150,6 @@ public class Basic { public static void Main(string[] args) { // Initialize the driver. var config = Config.DefaultConfig; - config.EncryptionLevel = EncryptionLevel.None; using(var driver = GraphDatabase.Driver("bolt://localhost:7687", AuthTokens.None, config)) using(var session = driver.Session()) { @@ -176,6 +167,18 @@ public class Basic { } ``` +### Secure Sockets Layer (SSL) + +Secure connections are supported and enabled by default. The server initially +ships with a self-signed testing certificate. The certificate can be replaced +by editing the following parameters in `/etc/memgraph/memgraph.conf`: +``` +--cert-file=/path/to/ssl/certificate.pem +--key-file=/path/to/ssl/privatekey.pem +``` +To disable SSL support and use insecure connections to the database you should +set both parameters (`--cert-file` and `--key-file`) to empty values. + ### Limitations Memgraph is currently in early stage, and has a number of limitations we plan @@ -186,9 +189,3 @@ to remove in future versions. Memgraph is currently single-user only. There is no way to control user privileges. The default user has read and write privileges over the whole database. - -#### Secure Sockets Layer (SSL) - -Secure connections are not supported. For this reason each client -driver needs to be configured not to use encryption. Consult driver-specific -guides for details. diff --git a/docs/user_technical/quick-start.md b/docs/user_technical/quick-start.md index 5ec525520..7c72a366b 100644 --- a/docs/user_technical/quick-start.md +++ b/docs/user_technical/quick-start.md @@ -183,7 +183,7 @@ After installing `neo4j-client`, connect to the running Memgraph instance by issuing the following shell command. ```bash -neo4j-client --insecure -u "" -p "" localhost 7687 +neo4j-client -u "" -p "" localhost 7687 ``` After the client has started it should present a command prompt similar to: @@ -191,7 +191,7 @@ After the client has started it should present a command prompt similar to: ```bash neo4j-client 2.1.3 Enter `:help` for usage hints. -Connected to 'neo4j://@localhost:7687' (insecure) +Connected to 'neo4j://@localhost:7687' neo4j> ``` diff --git a/init b/init index 1e87833c9..31abc6fe0 100755 --- a/init +++ b/init @@ -8,6 +8,7 @@ required_pkgs=(git arcanist # source code control curl wget # for downloading libs uuid-dev default-jre-headless # required by antlr libreadline-dev # for memgraph console + libssl-dev libboost-iostreams-dev libboost-serialization-dev python3 python-virtualenv python3-pip # for qa, macro_benchmark and stress tests diff --git a/release/debian/postinst b/release/debian/postinst index b589521ce..c078c7605 100644 --- a/release/debian/postinst +++ b/release/debian/postinst @@ -29,6 +29,16 @@ case "$1" in chmod 750 /var/log/memgraph || exit 1 # Make examples directory immutable (optional) chattr +i -R /usr/share/memgraph/examples || true + + # Generate SSL certificates + if [ ! -d /etc/memgraph/ssl ]; then + mkdir /etc/memgraph/ssl || exit 1 + openssl req -x509 -newkey rsa:4096 -days 3650 -nodes \ + -keyout /etc/memgraph/ssl/key.pem -out /etc/memgraph/ssl/cert.pem \ + -subj "/C=GB/ST=London/L=London/O=Memgraph Ltd./CN=Memgraph DB" || exit 1 + chown memgraph:memgraph /etc/memgraph/ssl/* || exit 1 + chmod 400 /etc/memgraph/ssl/* || exit 1 + fi ;; abort-upgrade|abort-remove|abort-deconfigure) diff --git a/release/rpm/memgraph.spec.in b/release/rpm/memgraph.spec.in index c6d3c4d53..500705548 100644 --- a/release/rpm/memgraph.spec.in +++ b/release/rpm/memgraph.spec.in @@ -71,6 +71,16 @@ chown memgraph:adm /var/log/memgraph || exit 1 chmod 750 /var/log/memgraph || exit 1 # Make examples directory immutable (optional) chattr +i -R /usr/share/memgraph/examples || true + +# Generate SSL certificates +if [ ! -d /etc/memgraph/ssl ]; then + mkdir /etc/memgraph/ssl || exit 1 + openssl req -x509 -newkey rsa:4096 -days 3650 -nodes \ + -keyout /etc/memgraph/ssl/key.pem -out /etc/memgraph/ssl/cert.pem \ + -subj "/C=GB/ST=London/L=London/O=Memgraph Ltd./CN=Memgraph DB" || exit 1 + chown memgraph:memgraph /etc/memgraph/ssl/* || exit 1 + chmod 400 /etc/memgraph/ssl/* || exit 1 +fi @RPM_SYMLINK_POSTINSTALL@ @CPACK_RPM_SPEC_POSTINSTALL@ diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 313f8d53c..54c51ab23 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -8,6 +8,10 @@ add_subdirectory(telemetry) # all memgraph src files set(memgraph_src_files communication/buffer.cpp + communication/client.cpp + communication/context.cpp + communication/helpers.cpp + communication/init.cpp communication/bolt/v1/decoder/decoded_value.cpp communication/rpc/client.cpp communication/rpc/protocol.cpp @@ -189,6 +193,7 @@ string(TOLOWER ${CMAKE_BUILD_TYPE} lower_build_type) # memgraph_lib depend on these libraries set(MEMGRAPH_ALL_LIBS stdc++fs Threads::Threads fmt cppitertools antlr_opencypher_parser_lib dl glog gflags capnp kj + ${OPENSSL_LIBRARIES} ${Boost_IOSTREAMS_LIBRARY_RELEASE} ${Boost_SERIALIZATION_LIBRARY_RELEASE} mg-utils mg-io) @@ -206,6 +211,7 @@ endif() # STATIC library used by memgraph executables add_library(memgraph_lib STATIC ${memgraph_src_files}) target_link_libraries(memgraph_lib ${MEMGRAPH_ALL_LIBS}) +target_include_directories(memgraph_lib PRIVATE ${OPENSSL_INCLUDE_DIR}) add_dependencies(memgraph_lib generate_opencypher_parser) add_dependencies(memgraph_lib generate_lcp) add_dependencies(memgraph_lib generate_capnp) diff --git a/src/communication/bolt/client.hpp b/src/communication/bolt/client.hpp index c5a8773dc..70aa7a396 100644 --- a/src/communication/bolt/client.hpp +++ b/src/communication/bolt/client.hpp @@ -32,9 +32,9 @@ struct QueryData { std::map metadata; }; -class Client { +class Client final { public: - Client() {} + explicit Client(communication::ClientContext *context) : client_(context) {} Client(const Client &) = delete; Client(Client &&) = delete; diff --git a/src/communication/buffer.hpp b/src/communication/buffer.hpp index 6351db0c6..feb4d0cf4 100644 --- a/src/communication/buffer.hpp +++ b/src/communication/buffer.hpp @@ -20,7 +20,7 @@ namespace communication { * stack where all execution when it is being done is being done on a single * thread. */ -class Buffer { +class Buffer final { private: // Initial capacity of the internal buffer. const size_t kBufferInitialSize = 65536; @@ -28,6 +28,11 @@ class Buffer { public: Buffer(); + Buffer(const Buffer &) = delete; + Buffer(Buffer &&) = delete; + Buffer &operator=(const Buffer &) = delete; + Buffer &operator=(Buffer &&) = delete; + /** * This class provides all functions from the buffer that are needed to allow * reading data from the buffer. @@ -36,6 +41,11 @@ class Buffer { public: ReadEnd(Buffer &buffer); + ReadEnd(const ReadEnd &) = delete; + ReadEnd(ReadEnd &&) = delete; + ReadEnd &operator=(const ReadEnd &) = delete; + ReadEnd &operator=(ReadEnd &&) = delete; + uint8_t *data(); size_t size() const; @@ -58,6 +68,11 @@ class Buffer { public: WriteEnd(Buffer &buffer); + WriteEnd(const WriteEnd &) = delete; + WriteEnd(WriteEnd &&) = delete; + WriteEnd &operator=(const WriteEnd &) = delete; + WriteEnd &operator=(WriteEnd &&) = delete; + io::network::StreamBuffer Allocate(); void Written(size_t len); diff --git a/src/communication/client.cpp b/src/communication/client.cpp new file mode 100644 index 000000000..6822f37cf --- /dev/null +++ b/src/communication/client.cpp @@ -0,0 +1,232 @@ +#include + +#include "communication/client.hpp" +#include "communication/helpers.hpp" + +namespace communication { + +Client::Client(ClientContext *context) : context_(context) {} + +Client::~Client() { + Close(); + ReleaseSslObjects(); +} + +bool Client::Connect(const io::network::Endpoint &endpoint) { + // Try to establish a socket connection. + if (!socket_.Connect(endpoint)) return false; + + // Enable TCP keep alive for all connections. + // Because we manually always set the `have_more` flag to the socket + // `Write` call we can disable the Nagle algorithm because we know that we + // are always sending optimal packets. Even if we don't send optimal + // packets, there will be no delay between packets and throughput won't + // suffer. + socket_.SetKeepAlive(); + socket_.SetNoDelay(); + + if (context_->use_ssl()) { + // Release leftover SSL objects. + ReleaseSslObjects(); + + // Create a new SSL object that will be used for SSL communication. + ssl_ = SSL_new(context_->context()); + if (ssl_ == nullptr) { + LOG(WARNING) << "Couldn't create client SSL object!"; + socket_.Close(); + return false; + } + + // Create a new BIO (block I/O) SSL object so that OpenSSL can communicate + // using our socket. We specify `BIO_NOCLOSE` to indicate to OpenSSL that + // it doesn't need to close the socket when destructing all objects (we + // handle that in our socket destructor). + bio_ = BIO_new_socket(socket_.fd(), BIO_NOCLOSE); + if (bio_ == nullptr) { + LOG(WARNING) << "Couldn't create client BIO object!"; + socket_.Close(); + return false; + } + + // Connect the BIO object to the SSL object so that OpenSSL knows which + // stream it should use for communication. We use the same object for both + // the read and write end. This function cannot fail. + SSL_set_bio(ssl_, bio_, bio_); + + // Clear all leftover errors. + ERR_clear_error(); + + // Perform the TLS handshake. + auto ret = SSL_connect(ssl_); + if (ret != 1) { + LOG(WARNING) << "Couldn't connect to SSL server: " << SslGetLastError(); + socket_.Close(); + return false; + } + } + + return true; +} + +bool Client::ErrorStatus() { return socket_.ErrorStatus(); } + +void Client::Shutdown() { socket_.Shutdown(); } + +void Client::Close() { + if (ssl_) { + // Perform an unidirectional SSL shutdown. That just means that we send + // the shutdown message and don't wait or care for the result. + SSL_shutdown(ssl_); + } + socket_.Close(); +} + +bool Client::Read(size_t len) { + size_t received = 0; + buffer_.write_end().Resize(buffer_.read_end().size() + len); + while (received < len) { + auto buff = buffer_.write_end().Allocate(); + if (ssl_) { + // We clear errors here to prevent errors piling up in the internal + // OpenSSL error queue. To see when could that be an issue read this: + // https://www.arangodb.com/2014/07/started-hate-openssl/ + ERR_clear_error(); + + // Read encrypted data from the socket using OpenSSL. + auto got = SSL_read(ssl_, buff.data, len - received); + + // Handle errors that might have occurred. + if (got < 0) { + auto err = SSL_get_error(ssl_, got); + if (err == SSL_ERROR_WANT_READ) { + // OpenSSL want's to read more data from the socket. We wait for + // more data to be ready and retry the call. + socket_.WaitForReadyRead(); + continue; + } else if (err == SSL_ERROR_WANT_WRITE) { + // The OpenSSL library probably wants to perform some kind of + // handshake so we wait for the socket to become ready for a write + // and call the read again. + socket_.WaitForReadyWrite(); + continue; + } else { + // This is a fatal error. + LOG(WARNING) << "Received an unexpected SSL error: " << err; + return false; + } + } else if (got == 0) { + // The server closed the connection. + return false; + } + + // Notify the buffer that it has new data. + buffer_.write_end().Written(got); + received += got; + } else { + // Read raw data from the socket. + auto got = socket_.Read(buff.data, len - received); + + if (got <= 0) { + // If `read` returns 0 the server has closed the connection. If `read` + // returns -1 all of the errors that could be found in `errno` are + // fatal errors (because we are using a blocking socket) so return a + // read failure. + return false; + } + + // Notify the buffer that it has new data. + buffer_.write_end().Written(got); + received += got; + } + } + return true; +} + +uint8_t *Client::GetData() { return buffer_.read_end().data(); } + +size_t Client::GetDataSize() { return buffer_.read_end().size(); } + +void Client::ShiftData(size_t len) { buffer_.read_end().Shift(len); } + +void Client::ClearData() { buffer_.read_end().Clear(); } + +bool Client::Write(const uint8_t *data, size_t len, bool have_more) { + if (ssl_) { + // `SSL_write` has the interface of a normal `write` call. Because of that + // we need to ensure that all data is written to the socket manually. + while (len > 0) { + // We clear errors here to prevent errors piling up in the internal + // OpenSSL error queue. To see when could that be an issue read this: + // https://www.arangodb.com/2014/07/started-hate-openssl/ + ERR_clear_error(); + + // Write data to the socket using OpenSSL. + auto written = SSL_write(ssl_, data, len); + if (written < 0) { + auto err = SSL_get_error(ssl_, written); + if (err == SSL_ERROR_WANT_READ) { + // OpenSSL wants to perform some kind of handshake, we need to + // ensure that there is data available for the next call to + // `SSL_write`. + socket_.WaitForReadyRead(); + } else if (err == SSL_ERROR_WANT_WRITE) { + // The socket probably returned WOULDBLOCK and we need to wait for + // the output buffers to clear and reattempt the send. + socket_.WaitForReadyWrite(); + } else { + // This is a fatal error. + return false; + } + } else if (written == 0) { + // The client closed the connection. + return false; + } else { + len -= written; + data += written; + } + } + return true; + } else { + return socket_.Write(data, len, have_more); + } +} + +bool Client::Write(const std::string &str, bool have_more) { + return Write(reinterpret_cast(str.data()), str.size(), + have_more); +} + +const io::network::Endpoint &Client::endpoint() { return socket_.endpoint(); } + +void Client::ReleaseSslObjects() { + // If we are using SSL we need to free the allocated objects. Here we only + // free the SSL object because the `SSL_free` function also automatically + // frees the BIO object. + if (ssl_) { + SSL_free(ssl_); + ssl_ = nullptr; + bio_ = nullptr; + } +} + +ClientInputStream::ClientInputStream(Client &client) : client_(client) {} + +uint8_t *ClientInputStream::data() { return client_.GetData(); } + +size_t ClientInputStream::size() const { return client_.GetDataSize(); } + +void ClientInputStream::Shift(size_t len) { client_.ShiftData(len); } + +void ClientInputStream::Clear() { client_.ClearData(); } + +ClientOutputStream::ClientOutputStream(Client &client) : client_(client) {} + +bool ClientOutputStream::Write(const uint8_t *data, size_t len, + bool have_more) { + return client_.Write(data, len, have_more); +} +bool ClientOutputStream::Write(const std::string &str, bool have_more) { + return client_.Write(str, have_more); +} + +} // namespace communication diff --git a/src/communication/client.hpp b/src/communication/client.hpp index 771cb464e..5f6e35906 100644 --- a/src/communication/client.hpp +++ b/src/communication/client.hpp @@ -1,6 +1,12 @@ #pragma once +#include +#include +#include + #include "communication/buffer.hpp" +#include "communication/context.hpp" +#include "communication/init.hpp" #include "io/network/endpoint.hpp" #include "io/network/socket.hpp" @@ -10,106 +16,115 @@ namespace communication { * This class implements a generic network Client. * It uses blocking sockets and provides an API that can be used to receive/send * data over the network connection. + * + * NOTE: If you use this client you **must** call `communication::Init()` from + * the `main` function before using the client! */ -class Client { +class Client final { public: + explicit Client(ClientContext *context); + + ~Client(); + + Client(const Client &) = delete; + Client(Client &&) = delete; + Client &operator=(const Client &) = delete; + Client &operator=(Client &&) = delete; + /** * This function connects to a remote server and returns whether the connect * succeeded. */ - bool Connect(const io::network::Endpoint &endpoint) { - if (!socket_.Connect(endpoint)) return false; - socket_.SetKeepAlive(); - socket_.SetNoDelay(); - return true; - } + bool Connect(const io::network::Endpoint &endpoint); /** * This function returns `true` if the socket is in an error state. */ - bool ErrorStatus() { return socket_.ErrorStatus(); } + bool ErrorStatus(); /** * This function shuts down the socket. */ - void Shutdown() { socket_.Shutdown(); } + void Shutdown(); /** * This function closes the socket. */ - void Close() { socket_.Close(); } + void Close(); /** * This function is used to receive `len` bytes from the socket and stores it * in an internal buffer. It returns `true` if the read succeeded and `false` * if it didn't. */ - bool Read(size_t len) { - size_t received = 0; - buffer_.write_end().Resize(buffer_.read_end().size() + len); - while (received < len) { - auto buff = buffer_.write_end().Allocate(); - int got = socket_.Read(buff.data, len - received); - if (got <= 0) return false; - buffer_.write_end().Written(got); - received += got; - } - return true; - } + bool Read(size_t len); /** * This function returns a pointer to the read data that is currently stored * in the client. */ - uint8_t *GetData() { return buffer_.read_end().data(); } + uint8_t *GetData(); /** * This function returns the size of the read data that is currently stored in * the client. */ - size_t GetDataSize() { return buffer_.read_end().size(); } + size_t GetDataSize(); /** * This function removes first `len` bytes from the data buffer. */ - void ShiftData(size_t len) { buffer_.read_end().Shift(len); } + void ShiftData(size_t len); /** * This function clears the data buffer. */ - void ClearData() { buffer_.read_end().Clear(); } + void ClearData(); - // Write end - bool Write(const uint8_t *data, size_t len, bool have_more = false) { - return socket_.Write(data, len, have_more); - } - bool Write(const std::string &str, bool have_more = false) { - return Write(reinterpret_cast(str.data()), str.size(), - have_more); - } + /** + * This function writes data to the socket. + * TODO (mferencevic): the `have_more` flag currently isn't supported when + * using OpenSSL + */ + bool Write(const uint8_t *data, size_t len, bool have_more = false); - const io::network::Endpoint &endpoint() { return socket_.endpoint(); } + /** + * This function writes data to the socket. + */ + bool Write(const std::string &str, bool have_more = false); + + const io::network::Endpoint &endpoint(); private: - io::network::Socket socket_; + void ReleaseSslObjects(); + io::network::Socket socket_; Buffer buffer_; + + ClientContext *context_; + SSL *ssl_{nullptr}; + BIO *bio_{nullptr}; }; /** * This class provides a stream-like input side object to the client. */ -class ClientInputStream { +class ClientInputStream final { public: - ClientInputStream(Client &client) : client_(client) {} + ClientInputStream(Client &client); - uint8_t *data() { return client_.GetData(); } + ClientInputStream(const ClientInputStream &) = delete; + ClientInputStream(ClientInputStream &&) = delete; + ClientInputStream &operator=(const ClientInputStream &) = delete; + ClientInputStream &operator=(ClientInputStream &&) = delete; - size_t size() const { return client_.GetDataSize(); } + uint8_t *data(); - void Shift(size_t len) { client_.ShiftData(len); } + size_t size() const; - void Clear() { client_.ClearData(); } + void Shift(size_t len); + + void Clear(); private: Client &client_; @@ -118,16 +133,18 @@ class ClientInputStream { /** * This class provides a stream-like output side object to the client. */ -class ClientOutputStream { +class ClientOutputStream final { public: - ClientOutputStream(Client &client) : client_(client) {} + ClientOutputStream(Client &client); - bool Write(const uint8_t *data, size_t len, bool have_more = false) { - return client_.Write(data, len, have_more); - } - bool Write(const std::string &str, bool have_more = false) { - return client_.Write(str, have_more); - } + ClientOutputStream(const ClientOutputStream &) = delete; + ClientOutputStream(ClientOutputStream &&) = delete; + ClientOutputStream &operator=(const ClientOutputStream &) = delete; + ClientOutputStream &operator=(ClientOutputStream &&) = delete; + + bool Write(const uint8_t *data, size_t len, bool have_more = false); + + bool Write(const std::string &str, bool have_more = false); private: Client &client_; diff --git a/src/communication/context.cpp b/src/communication/context.cpp new file mode 100644 index 000000000..0a2cb9b49 --- /dev/null +++ b/src/communication/context.cpp @@ -0,0 +1,80 @@ +#include + +#include "communication/context.hpp" + +namespace communication { + +ClientContext::ClientContext(bool use_ssl) : use_ssl_(use_ssl), ctx_(nullptr) { + if (use_ssl_) { + ctx_ = SSL_CTX_new(TLS_client_method()); + CHECK(ctx_ != nullptr) << "Couldn't create client SSL_CTX object!"; + + // Disable legacy SSL support. Other options can be seen here: + // https://www.openssl.org/docs/man1.0.2/ssl/SSL_CTX_set_options.html + SSL_CTX_set_options(ctx_, SSL_OP_NO_SSLv3); + } +} + +ClientContext::ClientContext(const std::string &key_file, + const std::string &cert_file) + : ClientContext(true) { + if (key_file != "" && cert_file != "") { + CHECK(SSL_CTX_use_certificate_file(ctx_, cert_file.c_str(), + SSL_FILETYPE_PEM) == 1) + << "Couldn't load client certificate from file: " << cert_file; + CHECK(SSL_CTX_use_PrivateKey_file(ctx_, key_file.c_str(), + SSL_FILETYPE_PEM) == 1) + << "Couldn't load client private key from file: " << key_file; + } +} + +SSL_CTX *ClientContext::context() { return ctx_; } + +bool ClientContext::use_ssl() { return use_ssl_; } + +ServerContext::ServerContext() : use_ssl_(false), ctx_(nullptr) {} + +ServerContext::ServerContext(const std::string &key_file, + const std::string &cert_file, + const std::string &ca_file, bool verify_peer) + : use_ssl_(true), ctx_(SSL_CTX_new(TLS_server_method())) { + // TODO (mferencevic): add support for encrypted private keys + // TODO (mferencevic): add certificate revocation list (CRL) + CHECK(SSL_CTX_use_certificate_file(ctx_, cert_file.c_str(), + SSL_FILETYPE_PEM) == 1) + << "Couldn't load server certificate from file: " << cert_file; + CHECK(SSL_CTX_use_PrivateKey_file(ctx_, key_file.c_str(), SSL_FILETYPE_PEM) == + 1) + << "Couldn't load server private key from file: " << key_file; + + // Disable legacy SSL support. Other options can be seen here: + // https://www.openssl.org/docs/man1.0.2/ssl/SSL_CTX_set_options.html + SSL_CTX_set_options(ctx_, SSL_OP_NO_SSLv3); + + if (ca_file != "") { + // Load the certificate authority file. + CHECK(SSL_CTX_load_verify_locations(ctx_, ca_file.c_str(), nullptr) == 1) + << "Couldn't load certificate authority from file: " << ca_file; + + if (verify_peer) { + // Add the CA to list of accepted CAs that is sent to the client. + STACK_OF(X509_NAME) *ca_names = SSL_load_client_CA_file(ca_file.c_str()); + CHECK(ca_names != nullptr) + << "Couldn't load certificate authority from file: " << ca_file; + // `ca_names` doesn' need to be free'd because we pass it to + // `SSL_CTX_set_client_CA_list`: + // https://mta.openssl.org/pipermail/openssl-users/2015-May/001363.html + SSL_CTX_set_client_CA_list(ctx_, ca_names); + + // Enable verification of the client certificate. + SSL_CTX_set_verify( + ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr); + } + } +} + +SSL_CTX *ServerContext::context() { return ctx_; } + +bool ServerContext::use_ssl() { return use_ssl_; } + +} // namespace communication diff --git a/src/communication/context.hpp b/src/communication/context.hpp new file mode 100644 index 000000000..3af50e69e --- /dev/null +++ b/src/communication/context.hpp @@ -0,0 +1,63 @@ +#pragma once + +#include + +#include + +namespace communication { + +class ClientContext final { + public: + /** + * This constructor constructs a ClientContext that can either not use SSL + * (`use_ssl` is `false` by default), or it constructs a ClientContext that + * doesn't use a client certificate when `use_ssl` is set to `true`. + */ + explicit ClientContext(bool use_ssl = false); + + /** + * This constructor constructs a ClientContext that uses SSL and uses the + * specific client private key and certificate combination. If the parameters + * `key_file` and `cert_file` are equal to "" then the constructor falls back + * to the above constructor that uses SSL without certificates. + */ + ClientContext(const std::string &key_file, const std::string &cert_file); + + SSL_CTX *context(); + + bool use_ssl(); + + private: + bool use_ssl_; + SSL_CTX *ctx_; +}; + +class ServerContext final { + public: + /** + * This constructor constructs a ServerContext that doesn't use SSL. + */ + ServerContext(); + + /** + * This constructor constructs a ServerContext that uses SSL. The parameters + * `key_file` and `cert_file` can't be "" because when setting up a server it + * is mandatory to supply a private key and certificate. The parameter + * `ca_file` can be "" because SSL doesn't necessarily need to check that the + * client has a valid certificate. If you specify `verify_peer` to be `true` + * to check that the client certificate is valid, then you need to supply a + * valid `ca_file` as well. + */ + ServerContext(const std::string &key_file, const std::string &cert_file, + const std::string &ca_file = "", bool verify_peer = false); + + SSL_CTX *context(); + + bool use_ssl(); + + private: + bool use_ssl_; + SSL_CTX *ctx_; +}; + +} // namespace communication diff --git a/src/communication/helpers.cpp b/src/communication/helpers.cpp new file mode 100644 index 000000000..0f638a02b --- /dev/null +++ b/src/communication/helpers.cpp @@ -0,0 +1,13 @@ +#include + +#include "communication/helpers.hpp" + +namespace communication { + +const std::string SslGetLastError() { + char buff[2048]; + auto err = ERR_get_error(); + ERR_error_string_n(err, buff, sizeof(buff)); + return std::string(buff); +} +} // namespace communication diff --git a/src/communication/helpers.hpp b/src/communication/helpers.hpp new file mode 100644 index 000000000..99d9de2cf --- /dev/null +++ b/src/communication/helpers.hpp @@ -0,0 +1,12 @@ +#pragma once + +#include + +namespace communication { + +/** + * This function reads and returns a string describing the last OpenSSL error. + */ +const std::string SslGetLastError(); + +} // namespace communication diff --git a/src/communication/init.cpp b/src/communication/init.cpp new file mode 100644 index 000000000..c9d60b77b --- /dev/null +++ b/src/communication/init.cpp @@ -0,0 +1,21 @@ +#include + +#include +#include +#include + +#include "utils/signals.hpp" + +namespace communication { + +void Init() { + // Initialize the OpenSSL library. + SSL_library_init(); + OpenSSL_add_ssl_algorithms(); + SSL_load_error_strings(); + ERR_load_crypto_strings(); + + // Ignore SIGPIPE. + CHECK(utils::SignalIgnore(utils::Signal::Pipe)) << "Couldn't ignore SIGPIPE!"; +} +} // namespace communication diff --git a/src/communication/init.hpp b/src/communication/init.hpp new file mode 100644 index 000000000..0712cf383 --- /dev/null +++ b/src/communication/init.hpp @@ -0,0 +1,16 @@ +#pragma once + +namespace communication { + +/** + * Call this function in each `main` file that uses the Communication stack. It + * is used to initialize all libraries (primarily OpenSSL) and to fix some + * issues also related to OpenSSL (handling of SIGPIPE). + * + * Description of OpenSSL init can be seen here: + * https://wiki.openssl.org/index.php/Library_Initialization + * + * NOTE: This function must be called **exactly** once. + */ +void Init(); +} // namespace communication diff --git a/src/communication/listener.hpp b/src/communication/listener.hpp index 1de89f6ab..e0417f53e 100644 --- a/src/communication/listener.hpp +++ b/src/communication/listener.hpp @@ -13,6 +13,7 @@ #include "communication/session.hpp" #include "io/network/epoll.hpp" #include "io/network/socket.hpp" +#include "utils/signals.hpp" #include "utils/thread.hpp" #include "utils/thread/sync.hpp" @@ -28,7 +29,7 @@ namespace communication { * expired. */ template -class Listener { +class Listener final { private: // The maximum number of events handled per execution thread is 1. This is // because each event represents the start of a network request and it doesn't @@ -39,14 +40,27 @@ class Listener { using SessionHandler = Session; public: - Listener(TSessionData &data, int inactivity_timeout_sec, - const std::string &service_name) + Listener(TSessionData &data, ServerContext *context, + int inactivity_timeout_sec, const std::string &service_name, + size_t workers_count) : data_(data), alive_(true), + context_(context), inactivity_timeout_sec_(inactivity_timeout_sec), service_name_(service_name) { + std::cout << "Starting " << workers_count << " " << service_name_ + << " workers" << std::endl; + for (size_t i = 0; i < workers_count; ++i) { + worker_threads_.emplace_back([this, service_name, i]() { + utils::ThreadSetName(fmt::format("{} worker {}", service_name, i + 1)); + while (alive_) { + WaitAndProcessEvents(); + } + }); + } + if (inactivity_timeout_sec_ > 0) { - thread_ = std::thread([this, service_name]() { + timeout_thread_ = std::thread([this, service_name]() { utils::ThreadSetName(fmt::format("{} timeout", service_name)); while (alive_) { { @@ -72,7 +86,10 @@ class Listener { ~Listener() { alive_.store(false); - if (thread_.joinable()) thread_.join(); + if (timeout_thread_.joinable()) timeout_thread_.join(); + for (auto &worker_thread : worker_threads_) { + worker_thread.join(); + } } Listener(const Listener &) = delete; @@ -88,20 +105,12 @@ class Listener { void AddConnection(io::network::Socket &&connection) { std::unique_lock guard(lock_); - // Set connection options. - // The socket is left to be a blocking socket, but when `Read` is called - // then a flag is manually set to enable non-blocking read that is used in - // conjunction with `EPOLLET`. That means that the socket is used in a - // non-blocking fashion for reads and a blocking fashion for writes. - connection.SetKeepAlive(); - connection.SetNoDelay(); - // Remember fd before moving connection into Session. int fd = connection.fd(); // Create a new Session for the connection. sessions_.push_back(std::make_unique( - std::move(connection), data_, inactivity_timeout_sec_)); + std::move(connection), data_, context_, inactivity_timeout_sec_)); // Register the connection in Epoll. // We want to listen to an incoming event which is edge triggered and @@ -113,6 +122,7 @@ class Listener { sessions_.back().get()); } + private: /** * This function polls the event queue and processes incoming data. * It is thread safe and is intended to be called from multiple threads and @@ -168,7 +178,6 @@ class Listener { } } - private: bool ExecuteSession(SessionHandler &session) { try { if (session.Execute()) { @@ -223,8 +232,11 @@ class Listener { utils::SpinLock lock_; std::vector> sessions_; - std::thread thread_; + std::thread timeout_thread_; + std::vector worker_threads_; std::atomic alive_; + + ServerContext *context_; const int inactivity_timeout_sec_; const std::string service_name_; }; diff --git a/src/communication/rpc/client.cpp b/src/communication/rpc/client.cpp index 0d99740bf..6a1c9b0fa 100644 --- a/src/communication/rpc/client.cpp +++ b/src/communication/rpc/client.cpp @@ -30,7 +30,7 @@ std::experimental::optional<::capnp::FlatArrayMessageReader> Client::Send( // Connect to the remote server. if (!client_) { - client_.emplace(); + client_.emplace(&context_); if (!client_->Connect(endpoint_)) { LOG(ERROR) << "Couldn't connect to remote address " << endpoint_; client_ = std::experimental::nullopt; diff --git a/src/communication/rpc/client.hpp b/src/communication/rpc/client.hpp index 47d7cffb4..971fb558f 100644 --- a/src/communication/rpc/client.hpp +++ b/src/communication/rpc/client.hpp @@ -86,6 +86,8 @@ class Client { ::capnp::MessageBuilder *message); io::network::Endpoint endpoint_; + // TODO (mferencevic): currently the RPC client is hardcoded not to use SSL + communication::ClientContext context_; std::experimental::optional client_; std::mutex mutex_; diff --git a/src/communication/rpc/server.cpp b/src/communication/rpc/server.cpp index 5304a698c..e0c697863 100644 --- a/src/communication/rpc/server.cpp +++ b/src/communication/rpc/server.cpp @@ -4,7 +4,7 @@ namespace communication::rpc { Server::Server(const io::network::Endpoint &endpoint, size_t workers_count) - : server_(endpoint, *this, -1, "RPC", workers_count) {} + : server_(endpoint, *this, &context_, -1, "RPC", workers_count) {} void Server::StopProcessingCalls() { server_.Shutdown(); diff --git a/src/communication/rpc/server.hpp b/src/communication/rpc/server.hpp index a39a25dd3..3a28efb01 100644 --- a/src/communication/rpc/server.hpp +++ b/src/communication/rpc/server.hpp @@ -78,6 +78,8 @@ class Server { ConcurrentMap callbacks_; std::mutex mutex_; + // TODO (mferencevic): currently the RPC server is hardcoded not to use SSL + communication::ServerContext context_; communication::Server server_; }; // namespace communication::rpc diff --git a/src/communication/server.hpp b/src/communication/server.hpp index 4e63e5778..9a57e63eb 100644 --- a/src/communication/server.hpp +++ b/src/communication/server.hpp @@ -10,6 +10,7 @@ #include #include +#include "communication/init.hpp" #include "communication/listener.hpp" #include "io/network/socket.hpp" #include "utils/thread.hpp" @@ -27,6 +28,9 @@ namespace communication { * Current Server achitecture: * incoming connection -> server -> listener -> session * + * NOTE: If you use this server you **must** call `communication::Init()` from + * the `main` function before using the server! + * * @tparam TSession the server can handle different Sessions, each session * represents a different protocol so the same network infrastructure * can be used for handling different protocols @@ -34,7 +38,7 @@ namespace communication { * session */ template -class Server { +class Server final { public: using Socket = io::network::Socket; @@ -43,9 +47,11 @@ class Server { * invokes workers_count workers */ Server(const io::network::Endpoint &endpoint, TSessionData &session_data, - int inactivity_timeout_sec, const std::string &service_name, + ServerContext *context, int inactivity_timeout_sec, + const std::string &service_name, size_t workers_count = std::thread::hardware_concurrency()) - : listener_(session_data, inactivity_timeout_sec, service_name), + : listener_(session_data, context, inactivity_timeout_sec, service_name, + workers_count), service_name_(service_name) { // Without server we can't continue with application so we can just // terminate here. @@ -58,18 +64,7 @@ class Server { } thread_ = std::thread([this, workers_count, service_name]() { - std::cout << "Starting " << workers_count << " " << service_name - << " workers" << std::endl; utils::ThreadSetName(fmt::format("{} server", service_name)); - for (size_t i = 0; i < workers_count; ++i) { - worker_threads_.emplace_back([this, service_name, i]() { - utils::ThreadSetName( - fmt::format("{} worker {}", service_name, i + 1)); - while (alive_) { - listener_.WaitAndProcessEvents(); - } - }); - } std::cout << service_name << " server is fully armed and operational" << std::endl; @@ -81,9 +76,6 @@ class Server { } std::cout << service_name << " shutting down..." << std::endl; - for (auto &worker_thread : worker_threads_) { - worker_thread.join(); - } }); } @@ -128,7 +120,6 @@ class Server { std::atomic alive_{true}; std::thread thread_; - std::vector worker_threads_; Socket socket_; Listener listener_; diff --git a/src/communication/session.hpp b/src/communication/session.hpp index 2cbab652e..c8db4a909 100644 --- a/src/communication/session.hpp +++ b/src/communication/session.hpp @@ -9,7 +9,13 @@ #include +#include +#include +#include + #include "communication/buffer.hpp" +#include "communication/context.hpp" +#include "communication/helpers.hpp" #include "io/network/socket.hpp" #include "io/network/stream_buffer.hpp" #include "utils/exceptions.hpp" @@ -35,12 +41,19 @@ using InputStream = Buffer::ReadEnd; * This is used to provide output from user sessions. All sessions used with the * network stack should use this class for their output stream. */ -class OutputStream { +class OutputStream final { public: - OutputStream(io::network::Socket &socket) : socket_(socket) {} + OutputStream( + std::function write_function) + : write_function_(write_function) {} + + OutputStream(const OutputStream &) = delete; + OutputStream(OutputStream &&) = delete; + OutputStream &operator=(const OutputStream &) = delete; + OutputStream &operator=(OutputStream &&) = delete; bool Write(const uint8_t *data, size_t len, bool have_more = false) { - return socket_.Write(data, len, have_more); + return write_function_(data, len, have_more); } bool Write(const std::string &str, bool have_more = false) { @@ -49,7 +62,7 @@ class OutputStream { } private: - io::network::Socket &socket_; + std::function write_function_; }; /** @@ -58,20 +71,73 @@ class OutputStream { * wrapping. */ template -class Session { +class Session final { public: Session(io::network::Socket &&socket, TSessionData &data, - int inactivity_timeout_sec) + ServerContext *context, int inactivity_timeout_sec) : socket_(std::move(socket)), - output_stream_(socket_), + output_stream_([this](const uint8_t *data, size_t len, bool have_more) { + return Write(data, len, have_more); + }), session_(data, input_buffer_.read_end(), output_stream_), - inactivity_timeout_sec_(inactivity_timeout_sec) {} + inactivity_timeout_sec_(inactivity_timeout_sec) { + // Set socket options. + // The socket is set to be a non-blocking socket. We use the socket in a + // non-blocking fashion for reads and manually simulate a blocking socket + // type for writes. This manual handling of writes is necessary because + // OpenSSL doesn't provide a way to add `recv` parameters to the `SSL_read` + // call so we can't have a blocking socket and use it in a non-blocking way + // only for reads. + // Keep alive is enabled so that the Kernel's TCP stack notifies us if a + // connection is broken and shouldn't be used anymore. + // Because we manually always set the `have_more` flag to the socket + // `Write` call we can disable the Nagle algorithm because we know that we + // are always sending optimal packets. Even if we don't send optimal + // packets, there will be no delay between packets and throughput won't + // suffer. + socket_.SetNonBlocking(); + socket_.SetKeepAlive(); + socket_.SetNoDelay(); + + // Prepare SSL if we should be using it. + if (context->use_ssl()) { + // Create a new SSL object that will be used for SSL communication. + ssl_ = SSL_new(context->context()); + CHECK(ssl_ != nullptr) << "Couldn't create server SSL object!"; + + // Create a new BIO (block I/O) SSL object so that OpenSSL can communicate + // using our socket. We specify `BIO_NOCLOSE` to indicate to OpenSSL that + // it doesn't need to close the socket when destructing all objects (we + // handle that in our socket destructor). + bio_ = BIO_new_socket(socket_.fd(), BIO_NOCLOSE); + CHECK(bio_ != nullptr) << "Couldn't create server BIO object!"; + + // Connect the BIO object to the SSL object so that OpenSSL knows which + // stream it should use for communication. We use the same object for both + // the read and write end. This function cannot fail. + SSL_set_bio(ssl_, bio_, bio_); + + // Indicate to OpenSSL that this connection is a server. The TLS handshake + // will be performed in the first `SSL_read` or `SSL_write` call. This + // function cannot fail. + SSL_set_accept_state(ssl_); + } + } Session(const Session &) = delete; Session(Session &&) = delete; Session &operator=(const Session &) = delete; Session &operator=(Session &&) = delete; + ~Session() { + // If we are using SSL we need to free the allocated objects. Here we only + // free the SSL object because the `SSL_free` function also automatically + // frees the BIO object. + if (ssl_) { + SSL_free(ssl_); + } + } + /** * This function is called from the communication stack when an event occurs * indicating that there is data waiting to be read. This function calls the @@ -87,29 +153,69 @@ class Session { // Allocate the buffer to fill the data. auto buf = input_buffer_.write_end().Allocate(); - // Read from the buffer at most buf.len bytes in a non-blocking fashion. - int len = socket_.Read(buf.data, buf.len, true); - // Check for read errors. - if (len == -1) { - // This means read would block or read was interrupted by signal, we - // return `true` to indicate that all data is processad and to stop - // reading of data. - if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) { - return true; + if (ssl_) { + // We clear errors here to prevent errors piling up in the internal + // OpenSSL error queue. To see when could that be an issue read this: + // https://www.arangodb.com/2014/07/started-hate-openssl/ + ERR_clear_error(); + + // Read data from the socket using the OpenSSL API. + auto len = SSL_read(ssl_, buf.data, buf.len); + + // Check for read errors. + if (len < 0) { + auto err = SSL_get_error(ssl_, len); + if (err == SSL_ERROR_WANT_READ) { + // OpenSSL want's to read more data from the socket. We return `true` + // to stop execution of the session to wait for more data to be + // received. + return true; + } else if (err == SSL_ERROR_WANT_WRITE) { + // The OpenSSL library wants to perfrom some kind of handshake so we + // wait for the socket to become ready for a write and call the read + // again. We return `false` so that the listener calls this function + // again. + socket_.WaitForReadyWrite(); + return false; + } else { + // This is a fatal error. + throw utils::BasicException(SslGetLastError()); + } + } else if (len == 0) { + // The client closed the connection. + throw SessionClosedException("Session was closed by the client."); + return false; + } else { + // Notify the input buffer that it has new data. + input_buffer_.write_end().Written(len); } - // Some other error occurred, throw an exception to start session cleanup. - throw utils::BasicException("Couldn't read data from socket!"); - } + } else { + // Read from the buffer at most buf.len bytes in a non-blocking fashion. + // Note, the `true` parameter for non-blocking here is redundant because + // the socket already is non-blocking. + auto len = socket_.Read(buf.data, buf.len, true); - // The client has closed the connection. - if (len == 0) { - throw SessionClosedException("Session was closed by client."); + // Check for read errors. + if (len == -1) { + // This means read would block or read was interrupted by signal, we + // return `true` to indicate that all data is processad and to stop + // reading of data. + if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) { + return true; + } + // Some other error occurred, throw an exception to start session + // cleanup. + throw utils::BasicException("Couldn't read data from the socket!"); + } else if (len == 0) { + // The client has closed the connection. + throw SessionClosedException("Session was closed by client."); + } else { + // Notify the input buffer that it has new data. + input_buffer_.write_end().Written(len); + } } - // Notify the input buffer that it has new data. - input_buffer_.write_end().Written(len); - // Execute the session. session_.Execute(); @@ -142,6 +248,52 @@ class Session { last_event_time_ = std::chrono::steady_clock::now(); } + // TODO (mferencevic): the `have_more` flag currently isn't supported + // when using OpenSSL + bool Write(const uint8_t *data, size_t len, bool have_more = false) { + if (ssl_) { + // `SSL_write` has the interface of a normal `write` call. Because of that + // we need to ensure that all data is written to the socket manually. + while (len > 0) { + // We clear errors here to prevent errors piling up in the internal + // OpenSSL error queue. To see when could that be an issue read this: + // https://www.arangodb.com/2014/07/started-hate-openssl/ + ERR_clear_error(); + + // Write data to the socket using OpenSSL. + auto written = SSL_write(ssl_, data, len); + if (written < 0) { + auto err = SSL_get_error(ssl_, written); + if (err == SSL_ERROR_WANT_READ) { + // OpenSSL wants to perform some kind of handshake, we need to + // ensure that there is data available for the next call to + // `SSL_write`. + socket_.WaitForReadyRead(); + } else if (err == SSL_ERROR_WANT_WRITE) { + // The socket probably returned WOULDBLOCK and we need to wait for + // the output buffers to clear and reattempt the send. + socket_.WaitForReadyWrite(); + } else { + // This is a fatal error. + return false; + } + } else if (written == 0) { + // The client closed the connection. + return false; + } else { + len -= written; + data += written; + } + } + return true; + } else { + // This function guarantees that all data will be written to the socket + // even if the socket is non-blocking. It will use a non-busy wait to send + // all data. + return socket_.Write(data, len, have_more); + } + } + // We own the socket. io::network::Socket socket_; @@ -157,5 +309,9 @@ class Session { std::chrono::steady_clock::now()}; utils::SpinLock lock_; const int inactivity_timeout_sec_; -}; + + // SSL objects. + SSL *ssl_{nullptr}; + BIO *bio_{nullptr}; +}; // namespace communication } // namespace communication diff --git a/src/io/network/socket.cpp b/src/io/network/socket.cpp index b2ea57c2b..1c962d266 100644 --- a/src/io/network/socket.cpp +++ b/src/io/network/socket.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -219,8 +220,13 @@ bool Socket::Write(const uint8_t *data, size_t len, bool have_more) { // Terminal error, return failure. return false; } - // Non-fatal error, retry. - continue; + // Non-fatal error, retry after the socket is ready. This is here to + // implement a non-busy wait. If we just continue with the loop we have a + // busy wait. + if (!WaitForReadyWrite()) return false; + } else if (written == 0) { + // The client closed the connection. + return false; } else { len -= written; data += written; @@ -234,7 +240,32 @@ bool Socket::Write(const std::string &s, bool have_more) { have_more); } -int Socket::Read(void *buffer, size_t len, bool nonblock) { +ssize_t Socket::Read(void *buffer, size_t len, bool nonblock) { return recv(socket_, buffer, len, nonblock ? MSG_DONTWAIT : 0); } + +bool Socket::WaitForReadyRead() { + struct pollfd p; + p.fd = socket_; + p.events = POLLIN; + // We call poll with one element in the poll fds array (first and second + // arguments), also we set the timeout to -1 to block indefinitely until an + // event occurs. + int ret = poll(&p, 1, -1); + if (ret < 1) return false; + return p.revents & POLLIN; +} + +bool Socket::WaitForReadyWrite() { + struct pollfd p; + p.fd = socket_; + p.events = POLLOUT; + // We call poll with one element in the poll fds array (first and second + // arguments), also we set the timeout to -1 to block indefinitely until an + // event occurs. + int ret = poll(&p, 1, -1); + if (ret < 1) return false; + return p.revents & POLLOUT; +} + } // namespace io::network diff --git a/src/io/network/socket.hpp b/src/io/network/socket.hpp index abc47968a..83b943062 100644 --- a/src/io/network/socket.hpp +++ b/src/io/network/socket.hpp @@ -153,7 +153,41 @@ class Socket { * == 0 if the client closed the connection * < 0 if an error has occurred */ - int Read(void *buffer, size_t len, bool nonblock = false); + ssize_t Read(void *buffer, size_t len, bool nonblock = false); + + /** + * Wait until the socket becomes ready for a `Read` operation. + * This function blocks indefinitely waiting for the socket to change its + * state. This function is useful when you need a blocking operation on a + * non-blocking socket, you can call this function to ensure that your next + * `Read` operation will succeed. + * + * The function returns `true` if the wait succeded (there is data waiting to + * be read from the socket) and returns `false` if the wait failed (the socket + * was closed or something else bad happened). + * + * @return wait success status: + * true if the wait succeeded + * false if the wait failed + */ + bool WaitForReadyRead(); + + /** + * Wait until the socket becomes ready for a `Write` operation. + * This function blocks indefinitely waiting for the socket to change its + * state. This function is useful when you need a blocking operation on a + * non-blocking socket, you can call this function to ensure that your next + * `Write` operation will succeed. + * + * The function returns `true` if the wait succeded (the socket can be written + * to) and returns `false` if the wait failed (the socket was closed or + * something else bad happened). + * + * @return wait success status: + * true if the wait succeeded + * false if the wait failed + */ + bool WaitForReadyWrite(); private: Socket(int fd, const Endpoint &endpoint) : socket_(fd), endpoint_(endpoint) {} diff --git a/src/memgraph_bolt.cpp b/src/memgraph_bolt.cpp index 4e22b9995..56b231c6a 100644 --- a/src/memgraph_bolt.cpp +++ b/src/memgraph_bolt.cpp @@ -28,6 +28,7 @@ using communication::bolt::SessionData; using SessionT = communication::bolt::Session; using ServerT = communication::Server; +using communication::ServerContext; // General purpose flags. DEFINE_string(interface, "0.0.0.0", @@ -41,6 +42,8 @@ DEFINE_VALIDATED_int32(session_inactivity_timeout, 1800, "Time in seconds after which inactive sessions will be " "closed.", FLAG_IN_RANGE(1, INT32_MAX)); +DEFINE_string(cert_file, "", "Certificate file to use."); +DEFINE_string(key_file, "", "Key file to use."); DEFINE_string(log_file, "", "Path to where the log should be stored."); DEFINE_HIDDEN_string( log_link_basename, "", @@ -142,6 +145,9 @@ int WithInit(int argc, char **argv, stats::InitStatsLogging(get_stats_prefix()); utils::OnScopeExit stop_stats([] { stats::StopStatsLogging(); }); + // Initialize the communication library. + communication::Init(); + // Start memory warning logger. utils::Scheduler mem_log_scheduler; if (FLAGS_memory_warning_threshold > 0) { @@ -160,9 +166,17 @@ void SingleNodeMain() { google::SetUsageMessage("Memgraph single-node database server"); database::SingleNode db; SessionData session_data{db}; + + ServerContext context; + std::string service_name = "Bolt"; + if (FLAGS_key_file != "" && FLAGS_cert_file != "") { + context = ServerContext(FLAGS_key_file, FLAGS_cert_file); + service_name = "BoltS"; + } + ServerT server({FLAGS_interface, static_cast(FLAGS_port)}, - session_data, FLAGS_session_inactivity_timeout, "Bolt", - FLAGS_num_workers); + session_data, &context, FLAGS_session_inactivity_timeout, + service_name, FLAGS_num_workers); // Setup telemetry std::experimental::optional telemetry; @@ -214,9 +228,17 @@ void MasterMain() { database::Master db; SessionData session_data{db}; + + ServerContext context; + std::string service_name = "Bolt"; + if (FLAGS_key_file != "" && FLAGS_cert_file != "") { + context = ServerContext(FLAGS_key_file, FLAGS_cert_file); + service_name = "BoltS"; + } + ServerT server({FLAGS_interface, static_cast(FLAGS_port)}, - session_data, FLAGS_session_inactivity_timeout, "Bolt", - FLAGS_num_workers); + session_data, &context, FLAGS_session_inactivity_timeout, + service_name, FLAGS_num_workers); // Handler for regular termination signals auto shutdown = [&server] { diff --git a/tests/concurrent/network_common.hpp b/tests/concurrent/network_common.hpp index f8a5a430b..bc985767c 100644 --- a/tests/concurrent/network_common.hpp +++ b/tests/concurrent/network_common.hpp @@ -42,10 +42,11 @@ class TestSession { input_stream_.Shift(size + 2); } - communication::InputStream input_stream_; - communication::OutputStream output_stream_; + communication::InputStream &input_stream_; + communication::OutputStream &output_stream_; }; +using ContextT = communication::ServerContext; using ServerT = communication::Server; void client_run(int num, const char *interface, uint16_t port, diff --git a/tests/concurrent/network_read_hang.cpp b/tests/concurrent/network_read_hang.cpp index d3a3dbeaf..c78d6e2f4 100644 --- a/tests/concurrent/network_read_hang.cpp +++ b/tests/concurrent/network_read_hang.cpp @@ -32,8 +32,8 @@ class TestSession { output_stream_.Write(input_stream_.data(), input_stream_.size()); } - communication::InputStream input_stream_; - communication::OutputStream output_stream_; + communication::InputStream &input_stream_; + communication::OutputStream &output_stream_; }; std::atomic run{true}; @@ -63,8 +63,9 @@ TEST(Network, SocketReadHangOnConcurrentConnections) { TestData data; int N = (std::thread::hardware_concurrency() + 1) / 2; int Nc = N * 3; - communication::Server server(endpoint, data, -1, - "Test", N); + communication::ServerContext context; + communication::Server server(endpoint, data, &context, + -1, "Test", N); const auto &ep = server.endpoint(); // start clients diff --git a/tests/concurrent/network_server.cpp b/tests/concurrent/network_server.cpp index bf5a6dcff..ec0c89c96 100644 --- a/tests/concurrent/network_server.cpp +++ b/tests/concurrent/network_server.cpp @@ -21,7 +21,8 @@ TEST(Network, Server) { // initialize server TestData session_data; int N = (std::thread::hardware_concurrency() + 1) / 2; - ServerT server(endpoint, session_data, -1, "Test", N); + ContextT context; + ServerT server(endpoint, session_data, &context, -1, "Test", N); const auto &ep = server.endpoint(); // start clients diff --git a/tests/concurrent/network_session_leak.cpp b/tests/concurrent/network_session_leak.cpp index 6e34e2bb3..6c868c645 100644 --- a/tests/concurrent/network_session_leak.cpp +++ b/tests/concurrent/network_session_leak.cpp @@ -22,7 +22,8 @@ TEST(Network, SessionLeak) { // initialize server TestData session_data; - ServerT server(endpoint, session_data, -1, "Test", 2); + ContextT context; + ServerT server(endpoint, session_data, &context, -1, "Test", 2); // start clients int N = 50; diff --git a/tests/integration/CMakeLists.txt b/tests/integration/CMakeLists.txt index 7ac72a0c4..a646112d7 100644 --- a/tests/integration/CMakeLists.txt +++ b/tests/integration/CMakeLists.txt @@ -1,2 +1,5 @@ # telemetry test binaries add_subdirectory(telemetry) + +# ssl test binaries +add_subdirectory(ssl) diff --git a/tests/integration/apollo_runs.yaml b/tests/integration/apollo_runs.yaml index 96c95d6b9..cce5c3078 100644 --- a/tests/integration/apollo_runs.yaml +++ b/tests/integration/apollo_runs.yaml @@ -6,3 +6,11 @@ - server.py # server script - ../../../build_debug/tests/integration/telemetry/client # client binary - ../../../build_debug/tests/manual/kvstore_console # kvstore console binary + +- name: integration__ssl + cd: ssl + commands: ./runner.sh + infiles: + - runner.sh # runner script + - ../../../build_debug/tests/integration/ssl/tester # tester binary + enable_network: true diff --git a/tests/integration/ssl/CMakeLists.txt b/tests/integration/ssl/CMakeLists.txt new file mode 100644 index 000000000..d860f9cbc --- /dev/null +++ b/tests/integration/ssl/CMakeLists.txt @@ -0,0 +1,6 @@ +set(target_name memgraph__integration__ssl) +set(tester_target_name ${target_name}__tester) + +add_executable(${tester_target_name} tester.cpp) +set_target_properties(${tester_target_name} PROPERTIES OUTPUT_NAME tester) +target_link_libraries(${tester_target_name} memgraph_lib kvstore_dummy_lib) diff --git a/tests/integration/ssl/runner.sh b/tests/integration/ssl/runner.sh new file mode 100755 index 000000000..57364fb41 --- /dev/null +++ b/tests/integration/ssl/runner.sh @@ -0,0 +1,78 @@ +#!/bin/bash -e + +pushd () { command pushd "$@" > /dev/null; } +popd () { command popd "$@" > /dev/null; } +die () { printf "\033[1;31m~~ Test failed! ~~\033[0m\n\n"; exit 1; } + +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +cd "$DIR" + +tmpdir=/tmp/memgraph_integration_ssl + +if [ -d $tmpdir ]; then + rm -rf $tmpdir +fi +mkdir -p $tmpdir +cd $tmpdir + +easyrsa="EasyRSA-3.0.4" +wget -nv http://deps.memgraph.io/$easyrsa.tgz + +tar -xf $easyrsa.tgz +mv $easyrsa ca1 +cp -r ca1 ca2 + +for i in ca1 ca2; do + pushd $i + ./easyrsa --batch init-pki + ./easyrsa --batch build-ca nopass + ./easyrsa --batch build-server-full server nopass + ./easyrsa --batch build-client-full client nopass + popd +done + +binary_dir="$DIR/../../../build" +if [ ! -d $binary_dir ]; then + binary_dir="$DIR/../../../build_debug" +fi + +set +e + +echo +for i in ca1 ca2; do + for j in none ca1 ca2; do + for k in false true; do + printf "\033[1;36m~~ Server CA: $i; Client CA: $j; Verify peer: $k ~~\033[0m\n" + + if [ "$j" == "none" ]; then + client_key="" + client_cert="" + else + client_key=$j/pki/private/client.key + client_cert=$j/pki/issued/client.crt + fi + + $binary_dir/tests/integration/ssl/tester \ + --server-key-file=$i/pki/private/server.key \ + --server-cert-file=$i/pki/issued/server.crt \ + --server-ca-file=$i/pki/ca.crt \ + --server-verify-peer=$k \ + --client-key-file=$client_key \ + --client-cert-file=$client_cert + + exitcode=$? + + if [ "$i" == "$j" ]; then + [ $exitcode -eq 0 ] || die + else + if $k; then + [ $exitcode -ne 0 ] || die + else + [ $exitcode -eq 0 ] || die + fi + fi + + printf "\033[1;32m~~ Test ok! ~~\033[0m\n\n" + done + done +done diff --git a/tests/integration/ssl/tester.cpp b/tests/integration/ssl/tester.cpp new file mode 100644 index 000000000..ca7557490 --- /dev/null +++ b/tests/integration/ssl/tester.cpp @@ -0,0 +1,76 @@ +#include + +#include +#include + +#include "communication/client.hpp" +#include "communication/server.hpp" +#include "utils/exceptions.hpp" + +DEFINE_string(server_cert_file, "", "Server certificate file to use."); +DEFINE_string(server_key_file, "", "Server key file to use."); +DEFINE_string(server_ca_file, "", "Server CA file to use."); +DEFINE_bool(server_verify_peer, false, "Set to true to verify the peer."); + +DEFINE_string(client_cert_file, "", "Client certificate file to use."); +DEFINE_string(client_key_file, "", "Client key file to use."); + +const std::string message = "ssl echo test"; + +struct EchoData {}; + +class EchoSession { + public: + EchoSession(EchoData &, communication::InputStream &input_stream, + communication::OutputStream &output_stream) + : input_stream_(input_stream), output_stream_(output_stream) {} + + void Execute() { + if (input_stream_.size() < message.size()) return; + LOG(INFO) << "Server received message."; + if (!output_stream_.Write(input_stream_.data(), message.size())) { + throw utils::BasicException("Output stream write failed!"); + } + LOG(INFO) << "Server sent message."; + input_stream_.Shift(message.size()); + } + + private: + communication::InputStream &input_stream_; + communication::OutputStream &output_stream_; +}; + +int main(int argc, char **argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + google::InitGoogleLogging(argv[0]); + + // Initialize the communication stack. + communication::Init(); + + // Initialize the server. + EchoData echo_data; + communication::ServerContext server_context( + FLAGS_server_key_file, FLAGS_server_cert_file, FLAGS_server_ca_file, + FLAGS_server_verify_peer); + communication::Server server( + {"127.0.0.1", 0}, echo_data, &server_context, -1, "SSL", 1); + + // Initialize the client. + communication::ClientContext client_context(FLAGS_client_key_file, + FLAGS_client_cert_file); + communication::Client client(&client_context); + + // Connect to the server. + CHECK(client.Connect(server.endpoint())) << "Couldn't connect to server!"; + + // Perform echo. + CHECK(client.Write(message)) << "Client couldn't send message!"; + LOG(INFO) << "Client sent message."; + CHECK(client.Read(message.size())) << "Client couldn't receive message!"; + LOG(INFO) << "Client received message."; + CHECK(std::string(reinterpret_cast(client.GetData()), + message.size()) == message) + << "Received message isn't equal to sent message!"; + + return 0; +} diff --git a/tests/macro_benchmark/clients/bolt_client.hpp b/tests/macro_benchmark/clients/bolt_client.hpp index 110f7cf23..b362b939b 100644 --- a/tests/macro_benchmark/clients/bolt_client.hpp +++ b/tests/macro_benchmark/clients/bolt_client.hpp @@ -10,6 +10,7 @@ #include "io/network/endpoint.hpp" using EndpointT = io::network::Endpoint; +using ContextT = communication::ClientContext; using ClientT = communication::bolt::Client; using QueryDataT = communication::bolt::QueryData; using communication::bolt::DecodedValue; @@ -18,23 +19,23 @@ class BoltClient { public: BoltClient(const std::string &address, uint16_t port, const std::string &username, const std::string &password, - const std::string & = "") { + const std::string & = "", bool use_ssl = false) + : context_(use_ssl), client_(context_) { EndpointT endpoint(address, port); - client_ = std::make_unique(); - if (!client_->Connect(endpoint, username, password)) { + if (!client_.Connect(endpoint, username, password)) { LOG(FATAL) << "Could not connect to: " << endpoint; } - } QueryDataT Execute(const std::string &query, const std::map ¶meters) { - return client_->Execute(query, parameters); + return client_.Execute(query, parameters); } - void Close() { client_->Close(); } + void Close() { client_.Close(); } private: - std::unique_ptr client_; + ContextT context_; + ClientT client_; }; diff --git a/tests/macro_benchmark/clients/card_fraud_client.cpp b/tests/macro_benchmark/clients/card_fraud_client.cpp index 08ba98cb3..fdba02efb 100644 --- a/tests/macro_benchmark/clients/card_fraud_client.cpp +++ b/tests/macro_benchmark/clients/card_fraud_client.cpp @@ -331,11 +331,14 @@ int main(int argc, char **argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); google::InitGoogleLogging(argv[0]); + communication::Init(); + stats::InitStatsLogging( fmt::format("client.long_running.{}.{}", FLAGS_group, FLAGS_scenario)); Endpoint endpoint(FLAGS_address, FLAGS_port); - Client client; + ClientContext context(FLAGS_use_ssl); + Client client(&context); if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) { LOG(FATAL) << "Couldn't connect to " << endpoint; } diff --git a/tests/macro_benchmark/clients/common.hpp b/tests/macro_benchmark/clients/common.hpp index b7039b690..f1583a6d5 100644 --- a/tests/macro_benchmark/clients/common.hpp +++ b/tests/macro_benchmark/clients/common.hpp @@ -11,6 +11,7 @@ #include "utils/exceptions.hpp" #include "utils/timer.hpp" +using communication::ClientContext; using communication::bolt::Client; using communication::bolt::DecodedValue; using io::network::Endpoint; diff --git a/tests/macro_benchmark/clients/long_running_common.hpp b/tests/macro_benchmark/clients/long_running_common.hpp index 40b94afb0..8ab7311ee 100644 --- a/tests/macro_benchmark/clients/long_running_common.hpp +++ b/tests/macro_benchmark/clients/long_running_common.hpp @@ -16,6 +16,7 @@ DEFINE_int32(num_workers, 1, "Number of workers"); DEFINE_string(output, "", "Output file"); DEFINE_string(username, "", "Username for the database"); DEFINE_string(password, "", "Password for the database"); +DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server."); DEFINE_int32(duration, 30, "Number of seconds to execute benchmark"); DEFINE_string(group, "unknown", "Test group name"); @@ -97,7 +98,8 @@ class TestClient { std::thread runner_thread_; private: - Client client_; + communication::ClientContext context_{FLAGS_use_ssl}; + Client client_{&context_}; }; void RunMultithreadedTest(std::vector> &clients) { diff --git a/tests/macro_benchmark/clients/pokec_client.cpp b/tests/macro_benchmark/clients/pokec_client.cpp index f1c191eb2..cf19edbea 100644 --- a/tests/macro_benchmark/clients/pokec_client.cpp +++ b/tests/macro_benchmark/clients/pokec_client.cpp @@ -271,12 +271,15 @@ int main(int argc, char **argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); google::InitGoogleLogging(argv[0]); + communication::Init(); + nlohmann::json config; std::cin >> config; auto independent_nodes_ids = [&] { Endpoint endpoint(io::network::ResolveHostname(FLAGS_address), FLAGS_port); - Client client; + ClientContext context(FLAGS_use_ssl); + Client client(&context); if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) { LOG(FATAL) << "Couldn't connect to " << endpoint; } diff --git a/tests/macro_benchmark/clients/query_client.cpp b/tests/macro_benchmark/clients/query_client.cpp index 5d28b04f6..224465265 100644 --- a/tests/macro_benchmark/clients/query_client.cpp +++ b/tests/macro_benchmark/clients/query_client.cpp @@ -19,6 +19,7 @@ DEFINE_string(address, "127.0.0.1", "Server address"); DEFINE_int32(port, 7687, "Server port"); DEFINE_string(username, "", "Username for the database"); DEFINE_string(password, "", "Password for the database"); +DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server."); using communication::bolt::DecodedValue; @@ -58,7 +59,8 @@ void ExecuteQueries(const std::vector &queries, for (int i = 0; i < FLAGS_num_workers; ++i) { threads.push_back(std::thread([&]() { Endpoint endpoint(FLAGS_address, FLAGS_port); - Client client; + ClientContext context(FLAGS_use_ssl); + Client client(&context); if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) { LOG(FATAL) << "Couldn't connect to " << endpoint; } @@ -100,6 +102,8 @@ int main(int argc, char **argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); google::InitGoogleLogging(argv[0]); + communication::Init(); + std::ifstream ifile; std::istream *istream{&std::cin}; diff --git a/tests/manual/CMakeLists.txt b/tests/manual/CMakeLists.txt index 854aa3a39..aab46eb77 100644 --- a/tests/manual/CMakeLists.txt +++ b/tests/manual/CMakeLists.txt @@ -71,6 +71,12 @@ target_link_libraries(${test_prefix}sl_position_and_count memgraph_lib kvstore_d add_manual_test(stripped_timing.cpp) target_link_libraries(${test_prefix}stripped_timing memgraph_lib kvstore_dummy_lib) +add_manual_test(ssl_client.cpp) +target_link_libraries(${test_prefix}ssl_client memgraph_lib kvstore_dummy_lib) + +add_manual_test(ssl_server.cpp) +target_link_libraries(${test_prefix}ssl_server memgraph_lib kvstore_dummy_lib) + add_manual_test(xorshift.cpp) target_link_libraries(${test_prefix}xorshift mg-utils) diff --git a/tests/manual/bolt_client.cpp b/tests/manual/bolt_client.cpp index facb9fae2..57f62847a 100644 --- a/tests/manual/bolt_client.cpp +++ b/tests/manual/bolt_client.cpp @@ -10,15 +10,20 @@ DEFINE_string(address, "127.0.0.1", "Server address"); DEFINE_int32(port, 7687, "Server port"); DEFINE_string(username, "", "Username for the database"); DEFINE_string(password, "", "Password for the database"); +DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server."); int main(int argc, char **argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); google::InitGoogleLogging(argv[0]); + communication::Init(); + // TODO: handle endpoint exception io::network::Endpoint endpoint(io::network::ResolveHostname(FLAGS_address), FLAGS_port); - communication::bolt::Client client; + + communication::ClientContext context(FLAGS_use_ssl); + communication::bolt::Client client(&context); if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) return 1; diff --git a/tests/manual/ssl_client.cpp b/tests/manual/ssl_client.cpp new file mode 100644 index 000000000..dad40798d --- /dev/null +++ b/tests/manual/ssl_client.cpp @@ -0,0 +1,65 @@ +#include +#include + +#include "communication/client.hpp" +#include "io/network/endpoint.hpp" +#include "utils/timer.hpp" + +DEFINE_string(address, "127.0.0.1", "Server address"); +DEFINE_int32(port, 54321, "Server port"); +DEFINE_string(cert_file, "", "Certificate file to use."); +DEFINE_string(key_file, "", "Key file to use."); + +bool EchoMessage(communication::Client &client, const std::string &data) { + uint16_t size = data.size(); + if (!client.Write(reinterpret_cast(&size), sizeof(size))) { + LOG(WARNING) << "Couldn't send data size!"; + return false; + } + if (!client.Write(data)) { + LOG(WARNING) << "Couldn't send data!"; + return false; + } + + client.ClearData(); + if (!client.Read(size)) { + LOG(WARNING) << "Couldn't receive data!"; + return false; + } + if (std::string(reinterpret_cast(client.GetData()), size) != + data) { + LOG(WARNING) << "Received data isn't equal to sent data!"; + return false; + } + return true; +} + +int main(int argc, char **argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + google::InitGoogleLogging(argv[0]); + + communication::Init(); + + io::network::Endpoint endpoint(FLAGS_address, FLAGS_port); + + communication::ClientContext context(FLAGS_key_file, FLAGS_cert_file); + communication::Client client(&context); + + if (!client.Connect(endpoint)) return 1; + + bool success = true; + while (true) { + std::string s; + std::getline(std::cin, s); + if (s == "") break; + if (!EchoMessage(client, s)) { + success = false; + break; + } + } + + // Send server shutdown signal. The call will fail, we don't care. + EchoMessage(client, ""); + + return success ? 0 : 1; +} diff --git a/tests/manual/ssl_server.cpp b/tests/manual/ssl_server.cpp new file mode 100644 index 000000000..cfdbee1fb --- /dev/null +++ b/tests/manual/ssl_server.cpp @@ -0,0 +1,73 @@ +#include + +#include +#include + +#include "communication/server.hpp" +#include "utils/exceptions.hpp" + +DEFINE_string(address, "127.0.0.1", "Server address"); +DEFINE_int32(port, 54321, "Server port"); +DEFINE_string(cert_file, "", "Certificate file to use."); +DEFINE_string(key_file, "", "Key file to use."); +DEFINE_string(ca_file, "", "CA file to use."); +DEFINE_bool(verify_peer, false, "Set to true to verify the peer."); + +struct EchoData { + std::atomic alive{true}; +}; + +class EchoSession { + public: + EchoSession(EchoData &data, communication::InputStream &input_stream, + communication::OutputStream &output_stream) + : data_(data), + input_stream_(input_stream), + output_stream_(output_stream) {} + + void Execute() { + if (input_stream_.size() < 2) return; + const uint8_t *data = input_stream_.data(); + uint16_t size = *reinterpret_cast(input_stream_.data()); + input_stream_.Resize(size + 2); + if (input_stream_.size() < size + 2) return; + if (size == 0) { + LOG(INFO) << "Server received EOF message"; + data_.alive.store(false); + return; + } + LOG(INFO) << "Server received '" + << std::string(reinterpret_cast(data + 2), size) + << "'"; + if (!output_stream_.Write(data + 2, size)) { + throw utils::BasicException("Output stream write failed!"); + } + input_stream_.Shift(size + 2); + } + + private: + EchoData &data_; + communication::InputStream &input_stream_; + communication::OutputStream &output_stream_; +}; + +int main(int argc, char **argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + google::InitGoogleLogging(argv[0]); + + communication::Init(); + + // Initialize the server. + EchoData echo_data; + io::network::Endpoint endpoint(FLAGS_address, FLAGS_port); + communication::ServerContext context(FLAGS_key_file, FLAGS_cert_file, + FLAGS_ca_file, FLAGS_verify_peer); + communication::Server server(endpoint, echo_data, + &context, -1, "SSL", 1); + + while (echo_data.alive) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + + return 0; +} diff --git a/tests/stress/.gitignore b/tests/stress/.gitignore index d9d22cbe9..9f14c5c5e 100644 --- a/tests/stress/.gitignore +++ b/tests/stress/.gitignore @@ -1 +1,2 @@ .long_running_stats +*.pem diff --git a/tests/stress/apollo_runs.yaml b/tests/stress/apollo_runs.yaml index 89e64370a..50fa6ca29 100644 --- a/tests/stress/apollo_runs.yaml +++ b/tests/stress/apollo_runs.yaml @@ -10,6 +10,10 @@ commands: TIMEOUT=600 ./continuous_integration --properties-on-disk infiles: *STRESS_INFILES +- name: stress_ssl + commands: TIMEOUT=600 ./continuous_integration --use-ssl + infiles: *STRESS_INFILES + - name: stress_large project: release commands: TIMEOUT=43200 ./continuous_integration --large-dataset diff --git a/tests/stress/common.py b/tests/stress/common.py index 32f4da136..da9158695 100644 --- a/tests/stress/common.py +++ b/tests/stress/common.py @@ -146,14 +146,14 @@ def connection_argument_parser(): ''' parser = ArgumentParser(description=__doc__) - parser.add_argument('--endpoint', type=str, default='localhost:7687', + parser.add_argument('--endpoint', type=str, default='127.0.0.1:7687', help='DBMS instance endpoint. ' 'Bolt protocol is the only option.') parser.add_argument('--username', type=str, default='neo4j', help='DBMS instance username.') parser.add_argument('--password', type=int, default='1234', help='DBMS instance password.') - parser.add_argument('--ssl-enabled', action='store_true', + parser.add_argument('--use-ssl', action='store_true', help="Is SSL enabled?") return parser @@ -163,7 +163,7 @@ def bolt_session(url, auth, ssl=False): ''' with wrapper around Bolt session. - :param url: str, e.g. "bolt://localhost:7687" + :param url: str, e.g. "bolt://127.0.0.1:7687" :param auth: auth method, goes directly to the Bolt driver constructor :param ssl: bool, is ssl enabled ''' @@ -183,14 +183,15 @@ def argument_session(args): :return: Bolt session context manager based on program arguments ''' return bolt_session('bolt://' + args.endpoint, - (args.username, str(args.password))) + (args.username, str(args.password)), + args.use_ssl) -def argument_driver(args, ssl=False): +def argument_driver(args): return GraphDatabase.driver( 'bolt://' + args.endpoint, auth=(args.username, str(args.password)), - encrypted=ssl) + encrypted=args.use_ssl) # This class is used to create and cache sessions. Session is cached by args # used to create it and process' pid in which it was created. This makes it easy diff --git a/tests/stress/continuous_integration b/tests/stress/continuous_integration index 124585aa1..6c408ffcf 100755 --- a/tests/stress/continuous_integration +++ b/tests/stress/continuous_integration @@ -73,6 +73,8 @@ BASE_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..")) BUILD_DIR = os.path.join(BASE_DIR, "build") CONFIG_DIR = os.path.join(BASE_DIR, "config") MEASUREMENTS_FILE = os.path.join(SCRIPT_DIR, ".apollo_measurements") +KEY_FILE = os.path.join(SCRIPT_DIR, ".key.pem") +CERT_FILE = os.path.join(SCRIPT_DIR, ".cert.pem") # long running stats file STATS_FILE = os.path.join(SCRIPT_DIR, ".long_running_stats") @@ -132,6 +134,8 @@ parser.add_argument("--python", default = os.path.join(SCRIPT_DIR, "ve3", "bin", "python3"), type = str) parser.add_argument("--large-dataset", action = "store_const", const = True, default = False) +parser.add_argument("--use-ssl", action = "store_const", + const = True, default = False) parser.add_argument("--verbose", action = "store_const", const = True, default = False) args = parser.parse_args() @@ -140,6 +144,14 @@ args = parser.parse_args() if not os.path.exists(args.memgraph): args.memgraph = os.path.join(BASE_DIR, "build_release", "memgraph") +# generate temporary SSL certs +if args.use_ssl: + # https://unix.stackexchange.com/questions/104171/create-ssl-certificate-non-interactively + subj = "/C=HR/ST=Zagreb/L=Zagreb/O=Memgraph/CN=db.memgraph.com" + subprocess.run(["openssl", "req", "-new", "-newkey", "rsa:4096", + "-days", "365", "-nodes", "-x509", "-subj", subj, + "-keyout", KEY_FILE, "-out", CERT_FILE], check=True) + # start memgraph cwd = os.path.dirname(args.memgraph) cmd = [args.memgraph, "--num-workers=" + str(THREADS)] @@ -151,6 +163,8 @@ if args.durability_directory: cmd += ["--durability-directory", args.durability_directory] if args.properties_on_disk: cmd += ["--properties-on-disk", "id,x"] +if args.use_ssl: + cmd += ["--cert-file", CERT_FILE, "--key-file", KEY_FILE] proc_mg = subprocess.Popen(cmd, cwd = cwd, env = {"MEMGRAPH_CONFIG": args.config}) time.sleep(1.0) @@ -167,6 +181,8 @@ def cleanup(): runtimes = {} dataset = LARGE_DATASET if args.large_dataset else SMALL_DATASET for test in dataset: + if args.use_ssl: + test["options"] += ["--use-ssl"] runtime = run_test(args, **test) runtimes[os.path.splitext(test["test"])[0]] = runtime @@ -176,6 +192,11 @@ ret_mg = proc_mg.wait() if ret_mg != 0: raise Exception("Memgraph binary returned non-zero ({})!".format(ret_mg)) +# cleanup certificates +if args.use_ssl: + os.remove(KEY_FILE) + os.remove(CERT_FILE) + # measurements measurements = "" for key, value in runtimes.items(): diff --git a/tests/stress/long_running.cpp b/tests/stress/long_running.cpp index bd7081b00..d70b97e71 100644 --- a/tests/stress/long_running.cpp +++ b/tests/stress/long_running.cpp @@ -8,6 +8,7 @@ #include "utils/timer.hpp" using EndpointT = io::network::Endpoint; +using ClientContextT = communication::ClientContext; using ClientT = communication::bolt::Client; using DecodedValueT = communication::bolt::DecodedValue; using QueryDataT = communication::bolt::QueryData; @@ -17,6 +18,7 @@ DEFINE_string(address, "127.0.0.1", "Server address"); DEFINE_int32(port, 7687, "Server port"); DEFINE_string(username, "", "Username for the database"); DEFINE_string(password, "", "Password for the database"); +DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server."); DEFINE_int32(vertex_count, 0, "The average number of vertices in the graph per worker"); @@ -51,7 +53,7 @@ class GraphSession { } EndpointT endpoint(FLAGS_address, FLAGS_port); - client_ = std::make_unique(); + client_ = std::make_unique(&context_); if (!client_->Connect(endpoint, FLAGS_username, FLAGS_password)) { throw utils::BasicException("Couldn't connect to server!"); @@ -60,6 +62,7 @@ class GraphSession { private: uint64_t id_; + ClientContextT context_{FLAGS_use_ssl}; std::unique_ptr client_; std::set vertices_; @@ -362,6 +365,8 @@ int main(int argc, char **argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); google::InitGoogleLogging(argv[0]); + communication::Init(); + CHECK(FLAGS_vertex_count > 0) << "Vertex count must be greater than 0!"; CHECK(FLAGS_edge_count > 0) << "Edge count must be greater than 0!"; @@ -369,7 +374,8 @@ int main(int argc, char **argv) { // create client EndpointT endpoint(FLAGS_address, FLAGS_port); - ClientT client; + ClientContextT context(FLAGS_use_ssl); + ClientT client(&context); if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) { throw utils::BasicException("Couldn't connect to server!"); } diff --git a/tests/unit/network_timeouts.cpp b/tests/unit/network_timeouts.cpp index 6285625f3..c68817e79 100644 --- a/tests/unit/network_timeouts.cpp +++ b/tests/unit/network_timeouts.cpp @@ -52,8 +52,9 @@ bool QueryServer(io::network::Socket &socket) { TEST(NetworkTimeouts, InactiveSession) { // Instantiate the server and set the session timeout to 2 seconds. TestData test_data; + communication::ServerContext context; communication::Server server{ - {"127.0.0.1", 0}, test_data, 2, "Test", 1}; + {"127.0.0.1", 0}, test_data, &context, 2, "Test", 1}; // Create the client and connect to the server. io::network::Socket client; diff --git a/tests/unit/socket.cpp b/tests/unit/socket.cpp new file mode 100644 index 000000000..a02fc13ed --- /dev/null +++ b/tests/unit/socket.cpp @@ -0,0 +1,94 @@ +#include +#include +#include + +#include +#include + +#include "io/network/socket.hpp" +#include "utils/timer.hpp" + +TEST(Socket, WaitForReadyRead) { + io::network::Socket server; + ASSERT_TRUE(server.Bind({"127.0.0.1", 0})); + ASSERT_TRUE(server.Listen(1024)); + + std::thread thread([&server] { + io::network::Socket client; + ASSERT_TRUE(client.Connect(server.endpoint())); + std::this_thread::sleep_for(std::chrono::milliseconds(2000)); + ASSERT_TRUE(client.Write("test")); + }); + + uint8_t buff[100]; + auto client = server.Accept(); + ASSERT_TRUE(client); + + client->SetNonBlocking(); + + ASSERT_EQ(client->Read(buff, sizeof(buff)), -1); + ASSERT_TRUE(errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR); + + utils::Timer timer; + ASSERT_TRUE(client->WaitForReadyRead()); + ASSERT_GT(timer.Elapsed().count(), 1.0); + + ASSERT_GT(client->Read(buff, sizeof(buff)), 0); + + thread.join(); +} + +TEST(Socket, WaitForReadyWrite) { + io::network::Socket server; + ASSERT_TRUE(server.Bind({"127.0.0.1", 0})); + ASSERT_TRUE(server.Listen(1024)); + + std::thread thread([&server] { + uint8_t buff[10000]; + io::network::Socket client; + ASSERT_TRUE(client.Connect(server.endpoint())); + client.SetNonBlocking(); + + // Decrease the TCP read buffer. + int len = 1024; + ASSERT_EQ(setsockopt(client.fd(), SOL_SOCKET, SO_RCVBUF, &len, sizeof(len)), + 0); + + std::this_thread::sleep_for(std::chrono::milliseconds(2000)); + while (true) { + int ret = client.Read(buff, sizeof(buff)); + if (ret == -1 && errno != EAGAIN && errno != EWOULDBLOCK && + errno != EINTR) { + std::raise(SIGPIPE); + } else if (ret == 0) { + break; + } + } + }); + + auto client = server.Accept(); + ASSERT_TRUE(client); + + client->SetNonBlocking(); + + // Decrease the TCP write buffer. + int len = 1024; + ASSERT_EQ(setsockopt(client->fd(), SOL_SOCKET, SO_SNDBUF, &len, sizeof(len)), + 0); + + utils::Timer timer; + for (int i = 0; i < 1000000; ++i) { + ASSERT_TRUE(client->Write("test")); + } + ASSERT_GT(timer.Elapsed().count(), 1.0); + + client->Close(); + + thread.join(); +} + +int main(int argc, char **argv) { + google::InitGoogleLogging(argv[0]); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +}