Implement SSL support for servers and clients

Summary:
This diff implements OpenSSL support in the network stack.
Currently SSL support is only enabled for Bolt connections,
support for RPC connections will be added in another diff.

Reviewers: buda, teon.banek

Reviewed By: buda

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D1328
This commit is contained in:
Matej Ferencevic 2018-06-20 17:44:47 +02:00
parent 44821a918c
commit 1d448d40ca
55 changed files with 1400 additions and 177 deletions

View File

@ -10,6 +10,7 @@
* Static vertices/edges id generators exposed through the Id Cypher function. * Static vertices/edges id generators exposed through the Id Cypher function.
* Properties on disk added. * Properties on disk added.
* Telemetry added. * Telemetry added.
* SSL support added.
* Add `toString` function to openCypher * Add `toString` function to openCypher
### Bug Fixes and Other Changes ### Bug Fixes and Other Changes

View File

@ -140,6 +140,9 @@ endif()
set(Boost_USE_STATIC_LIBS ON) set(Boost_USE_STATIC_LIBS ON)
find_package(Boost 1.62 REQUIRED COMPONENTS iostreams serialization) find_package(Boost 1.62 REQUIRED COMPONENTS iostreams serialization)
# OpenSSL
find_package(OpenSSL REQUIRED)
set(libs_dir ${CMAKE_SOURCE_DIR}/libs) set(libs_dir ${CMAKE_SOURCE_DIR}/libs)
add_subdirectory(libs EXCLUDE_FROM_ALL) add_subdirectory(libs EXCLUDE_FROM_ALL)
@ -320,6 +323,8 @@ set(CPACK_DEBIAN_PACKAGE_DESCRIPTION "${CPACK_PACKAGE_DESCRIPTION_SUMMARY}
Contains Memgraph, the graph database. It aims to deliver developers the Contains Memgraph, the graph database. It aims to deliver developers the
speed, simplicity and scale required to build the next generation of speed, simplicity and scale required to build the next generation of
applications driver by real-time connected data.") applications driver by real-time connected data.")
# Add `openssl` package to dependencies list. Used to generate SSL certificates.
set(CPACK_DEBIAN_memgraph_PACKAGE_DEPENDS "openssl (>= 1.1.0)")
# RPM specific # RPM specific
set(CPACK_RPM_PACKAGE_URL https://memgraph.com) set(CPACK_RPM_PACKAGE_URL https://memgraph.com)
@ -335,6 +340,8 @@ set(CPACK_RPM_USER_BINARY_SPECFILE "${CMAKE_SOURCE_DIR}/release/rpm/memgraph.spe
set(CPACK_RPM_PACKAGE_DESCRIPTION "Contains Memgraph, the graph database. set(CPACK_RPM_PACKAGE_DESCRIPTION "Contains Memgraph, the graph database.
It aims to deliver developers the speed, simplicity and scale required to build It aims to deliver developers the speed, simplicity and scale required to build
the next generation of applications driver by real-time connected data.") the next generation of applications driver by real-time connected data.")
# Add `openssl` package to dependencies list. Used to generate SSL certificates.
set(CPACK_RPM_memgraph_PACKAGE_REQUIRES "openssl >= 1.0.0")
# All variables must be set before including. # All variables must be set before including.
include(CPack) include(CPack)

View File

@ -16,6 +16,12 @@
# Port the server should listen on. # Port the server should listen on.
--port=7687 --port=7687
# Path to a SSL certificate file that should be used.
--cert-file=/etc/memgraph/ssl/cert.pem
# Path to a SSL key file that should be used.
--key-file=/etc/memgraph/ssl/key.pem
# Number of workers used by the Memgraph server. By default, this will be the # Number of workers used by the Memgraph server. By default, this will be the
# number of processing units available on the machine. # number of processing units available on the machine.
# --num-workers=8 # --num-workers=8

View File

@ -13,11 +13,9 @@ from neo4j.v1 import GraphDatabase, basic_auth
# Initialize and configure the driver. # Initialize and configure the driver.
# * provide the correct URL where Memgraph is reachable; # * provide the correct URL where Memgraph is reachable;
# * use an empty user name and password, and # * use an empty user name and password.
# * disable encryption (not supported).
driver = GraphDatabase.driver("bolt://localhost:7687", driver = GraphDatabase.driver("bolt://localhost:7687",
auth=basic_auth("", ""), auth=basic_auth("", ""))
encrypted=False)
# Start a session in which queries are executed. # Start a session in which queries are executed.
session = driver.session() session = driver.session()
@ -51,9 +49,7 @@ The details about Java driver can be found
[on GitHub](https://github.com/neo4j/neo4j-java-driver). [on GitHub](https://github.com/neo4j/neo4j-java-driver).
The example below is equivalent to Python example. Major difference is that The example below is equivalent to Python example. Major difference is that
`Config` object has to be created before the driver construction. Encryption `Config` object has to be created before the driver construction.
has to be disabled by calling `withoutEncryption` method against the `Config`
builder.
```java ```java
import org.neo4j.driver.v1.*; import org.neo4j.driver.v1.*;
@ -64,7 +60,7 @@ import java.util.*;
public class JavaQuickStart { public class JavaQuickStart {
public static void main(String[] args) { public static void main(String[] args) {
// Initialize driver. // Initialize driver.
Config config = Config.build().withoutEncryption().toConfig(); Config config = Config.build().toConfig();
Driver driver = GraphDatabase.driver("bolt://localhost:7687", Driver driver = GraphDatabase.driver("bolt://localhost:7687",
AuthTokens.basic("",""), AuthTokens.basic("",""),
config); config);
@ -93,9 +89,7 @@ public class JavaQuickStart {
The details about Javascript driver can be found The details about Javascript driver can be found
[on GitHub](https://github.com/neo4j/neo4j-javascript-driver). [on GitHub](https://github.com/neo4j/neo4j-javascript-driver).
The Javascript example below is equivalent to Python and Java examples. SSL The Javascript example below is equivalent to Python and Java examples.
can be disabled by passing `{encrypted: 'ENCRYPTION_OFF'}` during the driver
construction.
Here is an example related to `Node.js`. Memgraph doesn't have integrated Here is an example related to `Node.js`. Memgraph doesn't have integrated
support for `WebSocket` which is required during the execution in any web support for `WebSocket` which is required during the execution in any web
@ -109,8 +103,7 @@ proxy port.
```javascript ```javascript
var neo4j = require('neo4j-driver').v1; var neo4j = require('neo4j-driver').v1;
var driver = neo4j.driver("bolt://localhost:7687", var driver = neo4j.driver("bolt://localhost:7687",
neo4j.auth.basic("neo4j", "1234"), neo4j.auth.basic("neo4j", "1234"));
{encrypted: 'ENCRYPTION_OFF'});
var session = driver.session(); var session = driver.session();
function die() { function die() {
@ -146,8 +139,7 @@ run_query("MATCH (n) DETACH DELETE n", function (result) {
The C# driver is hosted The C# driver is hosted
[on GitHub](https://github.com/neo4j/neo4j-dotnet-driver). The example below [on GitHub](https://github.com/neo4j/neo4j-dotnet-driver). The example below
performs the same work as all of the previous examples. Encryption is disabled performs the same work as all of the previous examples.
by setting `EncryptionLevel.NONE` on the `Config`.
```csh ```csh
using System; using System;
@ -158,7 +150,6 @@ public class Basic {
public static void Main(string[] args) { public static void Main(string[] args) {
// Initialize the driver. // Initialize the driver.
var config = Config.DefaultConfig; var config = Config.DefaultConfig;
config.EncryptionLevel = EncryptionLevel.None;
using(var driver = GraphDatabase.Driver("bolt://localhost:7687", AuthTokens.None, config)) using(var driver = GraphDatabase.Driver("bolt://localhost:7687", AuthTokens.None, config))
using(var session = driver.Session()) using(var session = driver.Session())
{ {
@ -176,6 +167,18 @@ public class Basic {
} }
``` ```
### Secure Sockets Layer (SSL)
Secure connections are supported and enabled by default. The server initially
ships with a self-signed testing certificate. The certificate can be replaced
by editing the following parameters in `/etc/memgraph/memgraph.conf`:
```
--cert-file=/path/to/ssl/certificate.pem
--key-file=/path/to/ssl/privatekey.pem
```
To disable SSL support and use insecure connections to the database you should
set both parameters (`--cert-file` and `--key-file`) to empty values.
### Limitations ### Limitations
Memgraph is currently in early stage, and has a number of limitations we plan Memgraph is currently in early stage, and has a number of limitations we plan
@ -186,9 +189,3 @@ to remove in future versions.
Memgraph is currently single-user only. There is no way to control user Memgraph is currently single-user only. There is no way to control user
privileges. The default user has read and write privileges over the whole privileges. The default user has read and write privileges over the whole
database. database.
#### Secure Sockets Layer (SSL)
Secure connections are not supported. For this reason each client
driver needs to be configured not to use encryption. Consult driver-specific
guides for details.

View File

@ -183,7 +183,7 @@ After installing `neo4j-client`, connect to the running Memgraph instance by
issuing the following shell command. issuing the following shell command.
```bash ```bash
neo4j-client --insecure -u "" -p "" localhost 7687 neo4j-client -u "" -p "" localhost 7687
``` ```
After the client has started it should present a command prompt similar to: After the client has started it should present a command prompt similar to:
@ -191,7 +191,7 @@ After the client has started it should present a command prompt similar to:
```bash ```bash
neo4j-client 2.1.3 neo4j-client 2.1.3
Enter `:help` for usage hints. Enter `:help` for usage hints.
Connected to 'neo4j://@localhost:7687' (insecure) Connected to 'neo4j://@localhost:7687'
neo4j> neo4j>
``` ```

1
init
View File

@ -8,6 +8,7 @@ required_pkgs=(git arcanist # source code control
curl wget # for downloading libs curl wget # for downloading libs
uuid-dev default-jre-headless # required by antlr uuid-dev default-jre-headless # required by antlr
libreadline-dev # for memgraph console libreadline-dev # for memgraph console
libssl-dev
libboost-iostreams-dev libboost-iostreams-dev
libboost-serialization-dev libboost-serialization-dev
python3 python-virtualenv python3-pip # for qa, macro_benchmark and stress tests python3 python-virtualenv python3-pip # for qa, macro_benchmark and stress tests

View File

@ -29,6 +29,16 @@ case "$1" in
chmod 750 /var/log/memgraph || exit 1 chmod 750 /var/log/memgraph || exit 1
# Make examples directory immutable (optional) # Make examples directory immutable (optional)
chattr +i -R /usr/share/memgraph/examples || true chattr +i -R /usr/share/memgraph/examples || true
# Generate SSL certificates
if [ ! -d /etc/memgraph/ssl ]; then
mkdir /etc/memgraph/ssl || exit 1
openssl req -x509 -newkey rsa:4096 -days 3650 -nodes \
-keyout /etc/memgraph/ssl/key.pem -out /etc/memgraph/ssl/cert.pem \
-subj "/C=GB/ST=London/L=London/O=Memgraph Ltd./CN=Memgraph DB" || exit 1
chown memgraph:memgraph /etc/memgraph/ssl/* || exit 1
chmod 400 /etc/memgraph/ssl/* || exit 1
fi
;; ;;
abort-upgrade|abort-remove|abort-deconfigure) abort-upgrade|abort-remove|abort-deconfigure)

View File

@ -71,6 +71,16 @@ chown memgraph:adm /var/log/memgraph || exit 1
chmod 750 /var/log/memgraph || exit 1 chmod 750 /var/log/memgraph || exit 1
# Make examples directory immutable (optional) # Make examples directory immutable (optional)
chattr +i -R /usr/share/memgraph/examples || true chattr +i -R /usr/share/memgraph/examples || true
# Generate SSL certificates
if [ ! -d /etc/memgraph/ssl ]; then
mkdir /etc/memgraph/ssl || exit 1
openssl req -x509 -newkey rsa:4096 -days 3650 -nodes \
-keyout /etc/memgraph/ssl/key.pem -out /etc/memgraph/ssl/cert.pem \
-subj "/C=GB/ST=London/L=London/O=Memgraph Ltd./CN=Memgraph DB" || exit 1
chown memgraph:memgraph /etc/memgraph/ssl/* || exit 1
chmod 400 /etc/memgraph/ssl/* || exit 1
fi
@RPM_SYMLINK_POSTINSTALL@ @RPM_SYMLINK_POSTINSTALL@
@CPACK_RPM_SPEC_POSTINSTALL@ @CPACK_RPM_SPEC_POSTINSTALL@

View File

@ -8,6 +8,10 @@ add_subdirectory(telemetry)
# all memgraph src files # all memgraph src files
set(memgraph_src_files set(memgraph_src_files
communication/buffer.cpp communication/buffer.cpp
communication/client.cpp
communication/context.cpp
communication/helpers.cpp
communication/init.cpp
communication/bolt/v1/decoder/decoded_value.cpp communication/bolt/v1/decoder/decoded_value.cpp
communication/rpc/client.cpp communication/rpc/client.cpp
communication/rpc/protocol.cpp communication/rpc/protocol.cpp
@ -189,6 +193,7 @@ string(TOLOWER ${CMAKE_BUILD_TYPE} lower_build_type)
# memgraph_lib depend on these libraries # memgraph_lib depend on these libraries
set(MEMGRAPH_ALL_LIBS stdc++fs Threads::Threads fmt cppitertools set(MEMGRAPH_ALL_LIBS stdc++fs Threads::Threads fmt cppitertools
antlr_opencypher_parser_lib dl glog gflags capnp kj antlr_opencypher_parser_lib dl glog gflags capnp kj
${OPENSSL_LIBRARIES}
${Boost_IOSTREAMS_LIBRARY_RELEASE} ${Boost_IOSTREAMS_LIBRARY_RELEASE}
${Boost_SERIALIZATION_LIBRARY_RELEASE} ${Boost_SERIALIZATION_LIBRARY_RELEASE}
mg-utils mg-io) mg-utils mg-io)
@ -206,6 +211,7 @@ endif()
# STATIC library used by memgraph executables # STATIC library used by memgraph executables
add_library(memgraph_lib STATIC ${memgraph_src_files}) add_library(memgraph_lib STATIC ${memgraph_src_files})
target_link_libraries(memgraph_lib ${MEMGRAPH_ALL_LIBS}) target_link_libraries(memgraph_lib ${MEMGRAPH_ALL_LIBS})
target_include_directories(memgraph_lib PRIVATE ${OPENSSL_INCLUDE_DIR})
add_dependencies(memgraph_lib generate_opencypher_parser) add_dependencies(memgraph_lib generate_opencypher_parser)
add_dependencies(memgraph_lib generate_lcp) add_dependencies(memgraph_lib generate_lcp)
add_dependencies(memgraph_lib generate_capnp) add_dependencies(memgraph_lib generate_capnp)

View File

@ -32,9 +32,9 @@ struct QueryData {
std::map<std::string, DecodedValue> metadata; std::map<std::string, DecodedValue> metadata;
}; };
class Client { class Client final {
public: public:
Client() {} explicit Client(communication::ClientContext *context) : client_(context) {}
Client(const Client &) = delete; Client(const Client &) = delete;
Client(Client &&) = delete; Client(Client &&) = delete;

View File

@ -20,7 +20,7 @@ namespace communication {
* stack where all execution when it is being done is being done on a single * stack where all execution when it is being done is being done on a single
* thread. * thread.
*/ */
class Buffer { class Buffer final {
private: private:
// Initial capacity of the internal buffer. // Initial capacity of the internal buffer.
const size_t kBufferInitialSize = 65536; const size_t kBufferInitialSize = 65536;
@ -28,6 +28,11 @@ class Buffer {
public: public:
Buffer(); Buffer();
Buffer(const Buffer &) = delete;
Buffer(Buffer &&) = delete;
Buffer &operator=(const Buffer &) = delete;
Buffer &operator=(Buffer &&) = delete;
/** /**
* This class provides all functions from the buffer that are needed to allow * This class provides all functions from the buffer that are needed to allow
* reading data from the buffer. * reading data from the buffer.
@ -36,6 +41,11 @@ class Buffer {
public: public:
ReadEnd(Buffer &buffer); ReadEnd(Buffer &buffer);
ReadEnd(const ReadEnd &) = delete;
ReadEnd(ReadEnd &&) = delete;
ReadEnd &operator=(const ReadEnd &) = delete;
ReadEnd &operator=(ReadEnd &&) = delete;
uint8_t *data(); uint8_t *data();
size_t size() const; size_t size() const;
@ -58,6 +68,11 @@ class Buffer {
public: public:
WriteEnd(Buffer &buffer); WriteEnd(Buffer &buffer);
WriteEnd(const WriteEnd &) = delete;
WriteEnd(WriteEnd &&) = delete;
WriteEnd &operator=(const WriteEnd &) = delete;
WriteEnd &operator=(WriteEnd &&) = delete;
io::network::StreamBuffer Allocate(); io::network::StreamBuffer Allocate();
void Written(size_t len); void Written(size_t len);

View File

@ -0,0 +1,232 @@
#include <glog/logging.h>
#include "communication/client.hpp"
#include "communication/helpers.hpp"
namespace communication {
Client::Client(ClientContext *context) : context_(context) {}
Client::~Client() {
Close();
ReleaseSslObjects();
}
bool Client::Connect(const io::network::Endpoint &endpoint) {
// Try to establish a socket connection.
if (!socket_.Connect(endpoint)) return false;
// Enable TCP keep alive for all connections.
// Because we manually always set the `have_more` flag to the socket
// `Write` call we can disable the Nagle algorithm because we know that we
// are always sending optimal packets. Even if we don't send optimal
// packets, there will be no delay between packets and throughput won't
// suffer.
socket_.SetKeepAlive();
socket_.SetNoDelay();
if (context_->use_ssl()) {
// Release leftover SSL objects.
ReleaseSslObjects();
// Create a new SSL object that will be used for SSL communication.
ssl_ = SSL_new(context_->context());
if (ssl_ == nullptr) {
LOG(WARNING) << "Couldn't create client SSL object!";
socket_.Close();
return false;
}
// Create a new BIO (block I/O) SSL object so that OpenSSL can communicate
// using our socket. We specify `BIO_NOCLOSE` to indicate to OpenSSL that
// it doesn't need to close the socket when destructing all objects (we
// handle that in our socket destructor).
bio_ = BIO_new_socket(socket_.fd(), BIO_NOCLOSE);
if (bio_ == nullptr) {
LOG(WARNING) << "Couldn't create client BIO object!";
socket_.Close();
return false;
}
// Connect the BIO object to the SSL object so that OpenSSL knows which
// stream it should use for communication. We use the same object for both
// the read and write end. This function cannot fail.
SSL_set_bio(ssl_, bio_, bio_);
// Clear all leftover errors.
ERR_clear_error();
// Perform the TLS handshake.
auto ret = SSL_connect(ssl_);
if (ret != 1) {
LOG(WARNING) << "Couldn't connect to SSL server: " << SslGetLastError();
socket_.Close();
return false;
}
}
return true;
}
bool Client::ErrorStatus() { return socket_.ErrorStatus(); }
void Client::Shutdown() { socket_.Shutdown(); }
void Client::Close() {
if (ssl_) {
// Perform an unidirectional SSL shutdown. That just means that we send
// the shutdown message and don't wait or care for the result.
SSL_shutdown(ssl_);
}
socket_.Close();
}
bool Client::Read(size_t len) {
size_t received = 0;
buffer_.write_end().Resize(buffer_.read_end().size() + len);
while (received < len) {
auto buff = buffer_.write_end().Allocate();
if (ssl_) {
// We clear errors here to prevent errors piling up in the internal
// OpenSSL error queue. To see when could that be an issue read this:
// https://www.arangodb.com/2014/07/started-hate-openssl/
ERR_clear_error();
// Read encrypted data from the socket using OpenSSL.
auto got = SSL_read(ssl_, buff.data, len - received);
// Handle errors that might have occurred.
if (got < 0) {
auto err = SSL_get_error(ssl_, got);
if (err == SSL_ERROR_WANT_READ) {
// OpenSSL want's to read more data from the socket. We wait for
// more data to be ready and retry the call.
socket_.WaitForReadyRead();
continue;
} else if (err == SSL_ERROR_WANT_WRITE) {
// The OpenSSL library probably wants to perform some kind of
// handshake so we wait for the socket to become ready for a write
// and call the read again.
socket_.WaitForReadyWrite();
continue;
} else {
// This is a fatal error.
LOG(WARNING) << "Received an unexpected SSL error: " << err;
return false;
}
} else if (got == 0) {
// The server closed the connection.
return false;
}
// Notify the buffer that it has new data.
buffer_.write_end().Written(got);
received += got;
} else {
// Read raw data from the socket.
auto got = socket_.Read(buff.data, len - received);
if (got <= 0) {
// If `read` returns 0 the server has closed the connection. If `read`
// returns -1 all of the errors that could be found in `errno` are
// fatal errors (because we are using a blocking socket) so return a
// read failure.
return false;
}
// Notify the buffer that it has new data.
buffer_.write_end().Written(got);
received += got;
}
}
return true;
}
uint8_t *Client::GetData() { return buffer_.read_end().data(); }
size_t Client::GetDataSize() { return buffer_.read_end().size(); }
void Client::ShiftData(size_t len) { buffer_.read_end().Shift(len); }
void Client::ClearData() { buffer_.read_end().Clear(); }
bool Client::Write(const uint8_t *data, size_t len, bool have_more) {
if (ssl_) {
// `SSL_write` has the interface of a normal `write` call. Because of that
// we need to ensure that all data is written to the socket manually.
while (len > 0) {
// We clear errors here to prevent errors piling up in the internal
// OpenSSL error queue. To see when could that be an issue read this:
// https://www.arangodb.com/2014/07/started-hate-openssl/
ERR_clear_error();
// Write data to the socket using OpenSSL.
auto written = SSL_write(ssl_, data, len);
if (written < 0) {
auto err = SSL_get_error(ssl_, written);
if (err == SSL_ERROR_WANT_READ) {
// OpenSSL wants to perform some kind of handshake, we need to
// ensure that there is data available for the next call to
// `SSL_write`.
socket_.WaitForReadyRead();
} else if (err == SSL_ERROR_WANT_WRITE) {
// The socket probably returned WOULDBLOCK and we need to wait for
// the output buffers to clear and reattempt the send.
socket_.WaitForReadyWrite();
} else {
// This is a fatal error.
return false;
}
} else if (written == 0) {
// The client closed the connection.
return false;
} else {
len -= written;
data += written;
}
}
return true;
} else {
return socket_.Write(data, len, have_more);
}
}
bool Client::Write(const std::string &str, bool have_more) {
return Write(reinterpret_cast<const uint8_t *>(str.data()), str.size(),
have_more);
}
const io::network::Endpoint &Client::endpoint() { return socket_.endpoint(); }
void Client::ReleaseSslObjects() {
// If we are using SSL we need to free the allocated objects. Here we only
// free the SSL object because the `SSL_free` function also automatically
// frees the BIO object.
if (ssl_) {
SSL_free(ssl_);
ssl_ = nullptr;
bio_ = nullptr;
}
}
ClientInputStream::ClientInputStream(Client &client) : client_(client) {}
uint8_t *ClientInputStream::data() { return client_.GetData(); }
size_t ClientInputStream::size() const { return client_.GetDataSize(); }
void ClientInputStream::Shift(size_t len) { client_.ShiftData(len); }
void ClientInputStream::Clear() { client_.ClearData(); }
ClientOutputStream::ClientOutputStream(Client &client) : client_(client) {}
bool ClientOutputStream::Write(const uint8_t *data, size_t len,
bool have_more) {
return client_.Write(data, len, have_more);
}
bool ClientOutputStream::Write(const std::string &str, bool have_more) {
return client_.Write(str, have_more);
}
} // namespace communication

View File

@ -1,6 +1,12 @@
#pragma once #pragma once
#include <openssl/bio.h>
#include <openssl/err.h>
#include <openssl/ssl.h>
#include "communication/buffer.hpp" #include "communication/buffer.hpp"
#include "communication/context.hpp"
#include "communication/init.hpp"
#include "io/network/endpoint.hpp" #include "io/network/endpoint.hpp"
#include "io/network/socket.hpp" #include "io/network/socket.hpp"
@ -10,106 +16,115 @@ namespace communication {
* This class implements a generic network Client. * This class implements a generic network Client.
* It uses blocking sockets and provides an API that can be used to receive/send * It uses blocking sockets and provides an API that can be used to receive/send
* data over the network connection. * data over the network connection.
*
* NOTE: If you use this client you **must** call `communication::Init()` from
* the `main` function before using the client!
*/ */
class Client { class Client final {
public: public:
explicit Client(ClientContext *context);
~Client();
Client(const Client &) = delete;
Client(Client &&) = delete;
Client &operator=(const Client &) = delete;
Client &operator=(Client &&) = delete;
/** /**
* This function connects to a remote server and returns whether the connect * This function connects to a remote server and returns whether the connect
* succeeded. * succeeded.
*/ */
bool Connect(const io::network::Endpoint &endpoint) { bool Connect(const io::network::Endpoint &endpoint);
if (!socket_.Connect(endpoint)) return false;
socket_.SetKeepAlive();
socket_.SetNoDelay();
return true;
}
/** /**
* This function returns `true` if the socket is in an error state. * This function returns `true` if the socket is in an error state.
*/ */
bool ErrorStatus() { return socket_.ErrorStatus(); } bool ErrorStatus();
/** /**
* This function shuts down the socket. * This function shuts down the socket.
*/ */
void Shutdown() { socket_.Shutdown(); } void Shutdown();
/** /**
* This function closes the socket. * This function closes the socket.
*/ */
void Close() { socket_.Close(); } void Close();
/** /**
* This function is used to receive `len` bytes from the socket and stores it * This function is used to receive `len` bytes from the socket and stores it
* in an internal buffer. It returns `true` if the read succeeded and `false` * in an internal buffer. It returns `true` if the read succeeded and `false`
* if it didn't. * if it didn't.
*/ */
bool Read(size_t len) { bool Read(size_t len);
size_t received = 0;
buffer_.write_end().Resize(buffer_.read_end().size() + len);
while (received < len) {
auto buff = buffer_.write_end().Allocate();
int got = socket_.Read(buff.data, len - received);
if (got <= 0) return false;
buffer_.write_end().Written(got);
received += got;
}
return true;
}
/** /**
* This function returns a pointer to the read data that is currently stored * This function returns a pointer to the read data that is currently stored
* in the client. * in the client.
*/ */
uint8_t *GetData() { return buffer_.read_end().data(); } uint8_t *GetData();
/** /**
* This function returns the size of the read data that is currently stored in * This function returns the size of the read data that is currently stored in
* the client. * the client.
*/ */
size_t GetDataSize() { return buffer_.read_end().size(); } size_t GetDataSize();
/** /**
* This function removes first `len` bytes from the data buffer. * This function removes first `len` bytes from the data buffer.
*/ */
void ShiftData(size_t len) { buffer_.read_end().Shift(len); } void ShiftData(size_t len);
/** /**
* This function clears the data buffer. * This function clears the data buffer.
*/ */
void ClearData() { buffer_.read_end().Clear(); } void ClearData();
// Write end /**
bool Write(const uint8_t *data, size_t len, bool have_more = false) { * This function writes data to the socket.
return socket_.Write(data, len, have_more); * TODO (mferencevic): the `have_more` flag currently isn't supported when
} * using OpenSSL
bool Write(const std::string &str, bool have_more = false) { */
return Write(reinterpret_cast<const uint8_t *>(str.data()), str.size(), bool Write(const uint8_t *data, size_t len, bool have_more = false);
have_more);
}
const io::network::Endpoint &endpoint() { return socket_.endpoint(); } /**
* This function writes data to the socket.
*/
bool Write(const std::string &str, bool have_more = false);
const io::network::Endpoint &endpoint();
private: private:
io::network::Socket socket_; void ReleaseSslObjects();
io::network::Socket socket_;
Buffer buffer_; Buffer buffer_;
ClientContext *context_;
SSL *ssl_{nullptr};
BIO *bio_{nullptr};
}; };
/** /**
* This class provides a stream-like input side object to the client. * This class provides a stream-like input side object to the client.
*/ */
class ClientInputStream { class ClientInputStream final {
public: public:
ClientInputStream(Client &client) : client_(client) {} ClientInputStream(Client &client);
uint8_t *data() { return client_.GetData(); } ClientInputStream(const ClientInputStream &) = delete;
ClientInputStream(ClientInputStream &&) = delete;
ClientInputStream &operator=(const ClientInputStream &) = delete;
ClientInputStream &operator=(ClientInputStream &&) = delete;
size_t size() const { return client_.GetDataSize(); } uint8_t *data();
void Shift(size_t len) { client_.ShiftData(len); } size_t size() const;
void Clear() { client_.ClearData(); } void Shift(size_t len);
void Clear();
private: private:
Client &client_; Client &client_;
@ -118,16 +133,18 @@ class ClientInputStream {
/** /**
* This class provides a stream-like output side object to the client. * This class provides a stream-like output side object to the client.
*/ */
class ClientOutputStream { class ClientOutputStream final {
public: public:
ClientOutputStream(Client &client) : client_(client) {} ClientOutputStream(Client &client);
bool Write(const uint8_t *data, size_t len, bool have_more = false) { ClientOutputStream(const ClientOutputStream &) = delete;
return client_.Write(data, len, have_more); ClientOutputStream(ClientOutputStream &&) = delete;
} ClientOutputStream &operator=(const ClientOutputStream &) = delete;
bool Write(const std::string &str, bool have_more = false) { ClientOutputStream &operator=(ClientOutputStream &&) = delete;
return client_.Write(str, have_more);
} bool Write(const uint8_t *data, size_t len, bool have_more = false);
bool Write(const std::string &str, bool have_more = false);
private: private:
Client &client_; Client &client_;

View File

@ -0,0 +1,80 @@
#include <glog/logging.h>
#include "communication/context.hpp"
namespace communication {
ClientContext::ClientContext(bool use_ssl) : use_ssl_(use_ssl), ctx_(nullptr) {
if (use_ssl_) {
ctx_ = SSL_CTX_new(TLS_client_method());
CHECK(ctx_ != nullptr) << "Couldn't create client SSL_CTX object!";
// Disable legacy SSL support. Other options can be seen here:
// https://www.openssl.org/docs/man1.0.2/ssl/SSL_CTX_set_options.html
SSL_CTX_set_options(ctx_, SSL_OP_NO_SSLv3);
}
}
ClientContext::ClientContext(const std::string &key_file,
const std::string &cert_file)
: ClientContext(true) {
if (key_file != "" && cert_file != "") {
CHECK(SSL_CTX_use_certificate_file(ctx_, cert_file.c_str(),
SSL_FILETYPE_PEM) == 1)
<< "Couldn't load client certificate from file: " << cert_file;
CHECK(SSL_CTX_use_PrivateKey_file(ctx_, key_file.c_str(),
SSL_FILETYPE_PEM) == 1)
<< "Couldn't load client private key from file: " << key_file;
}
}
SSL_CTX *ClientContext::context() { return ctx_; }
bool ClientContext::use_ssl() { return use_ssl_; }
ServerContext::ServerContext() : use_ssl_(false), ctx_(nullptr) {}
ServerContext::ServerContext(const std::string &key_file,
const std::string &cert_file,
const std::string &ca_file, bool verify_peer)
: use_ssl_(true), ctx_(SSL_CTX_new(TLS_server_method())) {
// TODO (mferencevic): add support for encrypted private keys
// TODO (mferencevic): add certificate revocation list (CRL)
CHECK(SSL_CTX_use_certificate_file(ctx_, cert_file.c_str(),
SSL_FILETYPE_PEM) == 1)
<< "Couldn't load server certificate from file: " << cert_file;
CHECK(SSL_CTX_use_PrivateKey_file(ctx_, key_file.c_str(), SSL_FILETYPE_PEM) ==
1)
<< "Couldn't load server private key from file: " << key_file;
// Disable legacy SSL support. Other options can be seen here:
// https://www.openssl.org/docs/man1.0.2/ssl/SSL_CTX_set_options.html
SSL_CTX_set_options(ctx_, SSL_OP_NO_SSLv3);
if (ca_file != "") {
// Load the certificate authority file.
CHECK(SSL_CTX_load_verify_locations(ctx_, ca_file.c_str(), nullptr) == 1)
<< "Couldn't load certificate authority from file: " << ca_file;
if (verify_peer) {
// Add the CA to list of accepted CAs that is sent to the client.
STACK_OF(X509_NAME) *ca_names = SSL_load_client_CA_file(ca_file.c_str());
CHECK(ca_names != nullptr)
<< "Couldn't load certificate authority from file: " << ca_file;
// `ca_names` doesn' need to be free'd because we pass it to
// `SSL_CTX_set_client_CA_list`:
// https://mta.openssl.org/pipermail/openssl-users/2015-May/001363.html
SSL_CTX_set_client_CA_list(ctx_, ca_names);
// Enable verification of the client certificate.
SSL_CTX_set_verify(
ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr);
}
}
}
SSL_CTX *ServerContext::context() { return ctx_; }
bool ServerContext::use_ssl() { return use_ssl_; }
} // namespace communication

View File

@ -0,0 +1,63 @@
#pragma once
#include <string>
#include <openssl/ssl.h>
namespace communication {
class ClientContext final {
public:
/**
* This constructor constructs a ClientContext that can either not use SSL
* (`use_ssl` is `false` by default), or it constructs a ClientContext that
* doesn't use a client certificate when `use_ssl` is set to `true`.
*/
explicit ClientContext(bool use_ssl = false);
/**
* This constructor constructs a ClientContext that uses SSL and uses the
* specific client private key and certificate combination. If the parameters
* `key_file` and `cert_file` are equal to "" then the constructor falls back
* to the above constructor that uses SSL without certificates.
*/
ClientContext(const std::string &key_file, const std::string &cert_file);
SSL_CTX *context();
bool use_ssl();
private:
bool use_ssl_;
SSL_CTX *ctx_;
};
class ServerContext final {
public:
/**
* This constructor constructs a ServerContext that doesn't use SSL.
*/
ServerContext();
/**
* This constructor constructs a ServerContext that uses SSL. The parameters
* `key_file` and `cert_file` can't be "" because when setting up a server it
* is mandatory to supply a private key and certificate. The parameter
* `ca_file` can be "" because SSL doesn't necessarily need to check that the
* client has a valid certificate. If you specify `verify_peer` to be `true`
* to check that the client certificate is valid, then you need to supply a
* valid `ca_file` as well.
*/
ServerContext(const std::string &key_file, const std::string &cert_file,
const std::string &ca_file = "", bool verify_peer = false);
SSL_CTX *context();
bool use_ssl();
private:
bool use_ssl_;
SSL_CTX *ctx_;
};
} // namespace communication

View File

@ -0,0 +1,13 @@
#include <openssl/err.h>
#include "communication/helpers.hpp"
namespace communication {
const std::string SslGetLastError() {
char buff[2048];
auto err = ERR_get_error();
ERR_error_string_n(err, buff, sizeof(buff));
return std::string(buff);
}
} // namespace communication

View File

@ -0,0 +1,12 @@
#pragma once
#include <string>
namespace communication {
/**
* This function reads and returns a string describing the last OpenSSL error.
*/
const std::string SslGetLastError();
} // namespace communication

View File

@ -0,0 +1,21 @@
#include <glog/logging.h>
#include <openssl/bio.h>
#include <openssl/err.h>
#include <openssl/ssl.h>
#include "utils/signals.hpp"
namespace communication {
void Init() {
// Initialize the OpenSSL library.
SSL_library_init();
OpenSSL_add_ssl_algorithms();
SSL_load_error_strings();
ERR_load_crypto_strings();
// Ignore SIGPIPE.
CHECK(utils::SignalIgnore(utils::Signal::Pipe)) << "Couldn't ignore SIGPIPE!";
}
} // namespace communication

View File

@ -0,0 +1,16 @@
#pragma once
namespace communication {
/**
* Call this function in each `main` file that uses the Communication stack. It
* is used to initialize all libraries (primarily OpenSSL) and to fix some
* issues also related to OpenSSL (handling of SIGPIPE).
*
* Description of OpenSSL init can be seen here:
* https://wiki.openssl.org/index.php/Library_Initialization
*
* NOTE: This function must be called **exactly** once.
*/
void Init();
} // namespace communication

View File

@ -13,6 +13,7 @@
#include "communication/session.hpp" #include "communication/session.hpp"
#include "io/network/epoll.hpp" #include "io/network/epoll.hpp"
#include "io/network/socket.hpp" #include "io/network/socket.hpp"
#include "utils/signals.hpp"
#include "utils/thread.hpp" #include "utils/thread.hpp"
#include "utils/thread/sync.hpp" #include "utils/thread/sync.hpp"
@ -28,7 +29,7 @@ namespace communication {
* expired. * expired.
*/ */
template <class TSession, class TSessionData> template <class TSession, class TSessionData>
class Listener { class Listener final {
private: private:
// The maximum number of events handled per execution thread is 1. This is // The maximum number of events handled per execution thread is 1. This is
// because each event represents the start of a network request and it doesn't // because each event represents the start of a network request and it doesn't
@ -39,14 +40,27 @@ class Listener {
using SessionHandler = Session<TSession, TSessionData>; using SessionHandler = Session<TSession, TSessionData>;
public: public:
Listener(TSessionData &data, int inactivity_timeout_sec, Listener(TSessionData &data, ServerContext *context,
const std::string &service_name) int inactivity_timeout_sec, const std::string &service_name,
size_t workers_count)
: data_(data), : data_(data),
alive_(true), alive_(true),
context_(context),
inactivity_timeout_sec_(inactivity_timeout_sec), inactivity_timeout_sec_(inactivity_timeout_sec),
service_name_(service_name) { service_name_(service_name) {
std::cout << "Starting " << workers_count << " " << service_name_
<< " workers" << std::endl;
for (size_t i = 0; i < workers_count; ++i) {
worker_threads_.emplace_back([this, service_name, i]() {
utils::ThreadSetName(fmt::format("{} worker {}", service_name, i + 1));
while (alive_) {
WaitAndProcessEvents();
}
});
}
if (inactivity_timeout_sec_ > 0) { if (inactivity_timeout_sec_ > 0) {
thread_ = std::thread([this, service_name]() { timeout_thread_ = std::thread([this, service_name]() {
utils::ThreadSetName(fmt::format("{} timeout", service_name)); utils::ThreadSetName(fmt::format("{} timeout", service_name));
while (alive_) { while (alive_) {
{ {
@ -72,7 +86,10 @@ class Listener {
~Listener() { ~Listener() {
alive_.store(false); alive_.store(false);
if (thread_.joinable()) thread_.join(); if (timeout_thread_.joinable()) timeout_thread_.join();
for (auto &worker_thread : worker_threads_) {
worker_thread.join();
}
} }
Listener(const Listener &) = delete; Listener(const Listener &) = delete;
@ -88,20 +105,12 @@ class Listener {
void AddConnection(io::network::Socket &&connection) { void AddConnection(io::network::Socket &&connection) {
std::unique_lock<utils::SpinLock> guard(lock_); std::unique_lock<utils::SpinLock> guard(lock_);
// Set connection options.
// The socket is left to be a blocking socket, but when `Read` is called
// then a flag is manually set to enable non-blocking read that is used in
// conjunction with `EPOLLET`. That means that the socket is used in a
// non-blocking fashion for reads and a blocking fashion for writes.
connection.SetKeepAlive();
connection.SetNoDelay();
// Remember fd before moving connection into Session. // Remember fd before moving connection into Session.
int fd = connection.fd(); int fd = connection.fd();
// Create a new Session for the connection. // Create a new Session for the connection.
sessions_.push_back(std::make_unique<SessionHandler>( sessions_.push_back(std::make_unique<SessionHandler>(
std::move(connection), data_, inactivity_timeout_sec_)); std::move(connection), data_, context_, inactivity_timeout_sec_));
// Register the connection in Epoll. // Register the connection in Epoll.
// We want to listen to an incoming event which is edge triggered and // We want to listen to an incoming event which is edge triggered and
@ -113,6 +122,7 @@ class Listener {
sessions_.back().get()); sessions_.back().get());
} }
private:
/** /**
* This function polls the event queue and processes incoming data. * This function polls the event queue and processes incoming data.
* It is thread safe and is intended to be called from multiple threads and * It is thread safe and is intended to be called from multiple threads and
@ -168,7 +178,6 @@ class Listener {
} }
} }
private:
bool ExecuteSession(SessionHandler &session) { bool ExecuteSession(SessionHandler &session) {
try { try {
if (session.Execute()) { if (session.Execute()) {
@ -223,8 +232,11 @@ class Listener {
utils::SpinLock lock_; utils::SpinLock lock_;
std::vector<std::unique_ptr<SessionHandler>> sessions_; std::vector<std::unique_ptr<SessionHandler>> sessions_;
std::thread thread_; std::thread timeout_thread_;
std::vector<std::thread> worker_threads_;
std::atomic<bool> alive_; std::atomic<bool> alive_;
ServerContext *context_;
const int inactivity_timeout_sec_; const int inactivity_timeout_sec_;
const std::string service_name_; const std::string service_name_;
}; };

View File

@ -30,7 +30,7 @@ std::experimental::optional<::capnp::FlatArrayMessageReader> Client::Send(
// Connect to the remote server. // Connect to the remote server.
if (!client_) { if (!client_) {
client_.emplace(); client_.emplace(&context_);
if (!client_->Connect(endpoint_)) { if (!client_->Connect(endpoint_)) {
LOG(ERROR) << "Couldn't connect to remote address " << endpoint_; LOG(ERROR) << "Couldn't connect to remote address " << endpoint_;
client_ = std::experimental::nullopt; client_ = std::experimental::nullopt;

View File

@ -86,6 +86,8 @@ class Client {
::capnp::MessageBuilder *message); ::capnp::MessageBuilder *message);
io::network::Endpoint endpoint_; io::network::Endpoint endpoint_;
// TODO (mferencevic): currently the RPC client is hardcoded not to use SSL
communication::ClientContext context_;
std::experimental::optional<communication::Client> client_; std::experimental::optional<communication::Client> client_;
std::mutex mutex_; std::mutex mutex_;

View File

@ -4,7 +4,7 @@ namespace communication::rpc {
Server::Server(const io::network::Endpoint &endpoint, Server::Server(const io::network::Endpoint &endpoint,
size_t workers_count) size_t workers_count)
: server_(endpoint, *this, -1, "RPC", workers_count) {} : server_(endpoint, *this, &context_, -1, "RPC", workers_count) {}
void Server::StopProcessingCalls() { void Server::StopProcessingCalls() {
server_.Shutdown(); server_.Shutdown();

View File

@ -78,6 +78,8 @@ class Server {
ConcurrentMap<uint64_t, RpcCallback> callbacks_; ConcurrentMap<uint64_t, RpcCallback> callbacks_;
std::mutex mutex_; std::mutex mutex_;
// TODO (mferencevic): currently the RPC server is hardcoded not to use SSL
communication::ServerContext context_;
communication::Server<Session, Server> server_; communication::Server<Session, Server> server_;
}; // namespace communication::rpc }; // namespace communication::rpc

View File

@ -10,6 +10,7 @@
#include <fmt/format.h> #include <fmt/format.h>
#include <glog/logging.h> #include <glog/logging.h>
#include "communication/init.hpp"
#include "communication/listener.hpp" #include "communication/listener.hpp"
#include "io/network/socket.hpp" #include "io/network/socket.hpp"
#include "utils/thread.hpp" #include "utils/thread.hpp"
@ -27,6 +28,9 @@ namespace communication {
* Current Server achitecture: * Current Server achitecture:
* incoming connection -> server -> listener -> session * incoming connection -> server -> listener -> session
* *
* NOTE: If you use this server you **must** call `communication::Init()` from
* the `main` function before using the server!
*
* @tparam TSession the server can handle different Sessions, each session * @tparam TSession the server can handle different Sessions, each session
* represents a different protocol so the same network infrastructure * represents a different protocol so the same network infrastructure
* can be used for handling different protocols * can be used for handling different protocols
@ -34,7 +38,7 @@ namespace communication {
* session * session
*/ */
template <typename TSession, typename TSessionData> template <typename TSession, typename TSessionData>
class Server { class Server final {
public: public:
using Socket = io::network::Socket; using Socket = io::network::Socket;
@ -43,9 +47,11 @@ class Server {
* invokes workers_count workers * invokes workers_count workers
*/ */
Server(const io::network::Endpoint &endpoint, TSessionData &session_data, Server(const io::network::Endpoint &endpoint, TSessionData &session_data,
int inactivity_timeout_sec, const std::string &service_name, ServerContext *context, int inactivity_timeout_sec,
const std::string &service_name,
size_t workers_count = std::thread::hardware_concurrency()) size_t workers_count = std::thread::hardware_concurrency())
: listener_(session_data, inactivity_timeout_sec, service_name), : listener_(session_data, context, inactivity_timeout_sec, service_name,
workers_count),
service_name_(service_name) { service_name_(service_name) {
// Without server we can't continue with application so we can just // Without server we can't continue with application so we can just
// terminate here. // terminate here.
@ -58,18 +64,7 @@ class Server {
} }
thread_ = std::thread([this, workers_count, service_name]() { thread_ = std::thread([this, workers_count, service_name]() {
std::cout << "Starting " << workers_count << " " << service_name
<< " workers" << std::endl;
utils::ThreadSetName(fmt::format("{} server", service_name)); utils::ThreadSetName(fmt::format("{} server", service_name));
for (size_t i = 0; i < workers_count; ++i) {
worker_threads_.emplace_back([this, service_name, i]() {
utils::ThreadSetName(
fmt::format("{} worker {}", service_name, i + 1));
while (alive_) {
listener_.WaitAndProcessEvents();
}
});
}
std::cout << service_name << " server is fully armed and operational" std::cout << service_name << " server is fully armed and operational"
<< std::endl; << std::endl;
@ -81,9 +76,6 @@ class Server {
} }
std::cout << service_name << " shutting down..." << std::endl; std::cout << service_name << " shutting down..." << std::endl;
for (auto &worker_thread : worker_threads_) {
worker_thread.join();
}
}); });
} }
@ -128,7 +120,6 @@ class Server {
std::atomic<bool> alive_{true}; std::atomic<bool> alive_{true};
std::thread thread_; std::thread thread_;
std::vector<std::thread> worker_threads_;
Socket socket_; Socket socket_;
Listener<TSession, TSessionData> listener_; Listener<TSession, TSessionData> listener_;

View File

@ -9,7 +9,13 @@
#include <glog/logging.h> #include <glog/logging.h>
#include <openssl/bio.h>
#include <openssl/err.h>
#include <openssl/ssl.h>
#include "communication/buffer.hpp" #include "communication/buffer.hpp"
#include "communication/context.hpp"
#include "communication/helpers.hpp"
#include "io/network/socket.hpp" #include "io/network/socket.hpp"
#include "io/network/stream_buffer.hpp" #include "io/network/stream_buffer.hpp"
#include "utils/exceptions.hpp" #include "utils/exceptions.hpp"
@ -35,12 +41,19 @@ using InputStream = Buffer::ReadEnd;
* This is used to provide output from user sessions. All sessions used with the * This is used to provide output from user sessions. All sessions used with the
* network stack should use this class for their output stream. * network stack should use this class for their output stream.
*/ */
class OutputStream { class OutputStream final {
public: public:
OutputStream(io::network::Socket &socket) : socket_(socket) {} OutputStream(
std::function<bool(const uint8_t *, size_t, bool)> write_function)
: write_function_(write_function) {}
OutputStream(const OutputStream &) = delete;
OutputStream(OutputStream &&) = delete;
OutputStream &operator=(const OutputStream &) = delete;
OutputStream &operator=(OutputStream &&) = delete;
bool Write(const uint8_t *data, size_t len, bool have_more = false) { bool Write(const uint8_t *data, size_t len, bool have_more = false) {
return socket_.Write(data, len, have_more); return write_function_(data, len, have_more);
} }
bool Write(const std::string &str, bool have_more = false) { bool Write(const std::string &str, bool have_more = false) {
@ -49,7 +62,7 @@ class OutputStream {
} }
private: private:
io::network::Socket &socket_; std::function<bool(const uint8_t *, size_t, bool)> write_function_;
}; };
/** /**
@ -58,20 +71,73 @@ class OutputStream {
* wrapping. * wrapping.
*/ */
template <class TSession, class TSessionData> template <class TSession, class TSessionData>
class Session { class Session final {
public: public:
Session(io::network::Socket &&socket, TSessionData &data, Session(io::network::Socket &&socket, TSessionData &data,
int inactivity_timeout_sec) ServerContext *context, int inactivity_timeout_sec)
: socket_(std::move(socket)), : socket_(std::move(socket)),
output_stream_(socket_), output_stream_([this](const uint8_t *data, size_t len, bool have_more) {
return Write(data, len, have_more);
}),
session_(data, input_buffer_.read_end(), output_stream_), session_(data, input_buffer_.read_end(), output_stream_),
inactivity_timeout_sec_(inactivity_timeout_sec) {} inactivity_timeout_sec_(inactivity_timeout_sec) {
// Set socket options.
// The socket is set to be a non-blocking socket. We use the socket in a
// non-blocking fashion for reads and manually simulate a blocking socket
// type for writes. This manual handling of writes is necessary because
// OpenSSL doesn't provide a way to add `recv` parameters to the `SSL_read`
// call so we can't have a blocking socket and use it in a non-blocking way
// only for reads.
// Keep alive is enabled so that the Kernel's TCP stack notifies us if a
// connection is broken and shouldn't be used anymore.
// Because we manually always set the `have_more` flag to the socket
// `Write` call we can disable the Nagle algorithm because we know that we
// are always sending optimal packets. Even if we don't send optimal
// packets, there will be no delay between packets and throughput won't
// suffer.
socket_.SetNonBlocking();
socket_.SetKeepAlive();
socket_.SetNoDelay();
// Prepare SSL if we should be using it.
if (context->use_ssl()) {
// Create a new SSL object that will be used for SSL communication.
ssl_ = SSL_new(context->context());
CHECK(ssl_ != nullptr) << "Couldn't create server SSL object!";
// Create a new BIO (block I/O) SSL object so that OpenSSL can communicate
// using our socket. We specify `BIO_NOCLOSE` to indicate to OpenSSL that
// it doesn't need to close the socket when destructing all objects (we
// handle that in our socket destructor).
bio_ = BIO_new_socket(socket_.fd(), BIO_NOCLOSE);
CHECK(bio_ != nullptr) << "Couldn't create server BIO object!";
// Connect the BIO object to the SSL object so that OpenSSL knows which
// stream it should use for communication. We use the same object for both
// the read and write end. This function cannot fail.
SSL_set_bio(ssl_, bio_, bio_);
// Indicate to OpenSSL that this connection is a server. The TLS handshake
// will be performed in the first `SSL_read` or `SSL_write` call. This
// function cannot fail.
SSL_set_accept_state(ssl_);
}
}
Session(const Session &) = delete; Session(const Session &) = delete;
Session(Session &&) = delete; Session(Session &&) = delete;
Session &operator=(const Session &) = delete; Session &operator=(const Session &) = delete;
Session &operator=(Session &&) = delete; Session &operator=(Session &&) = delete;
~Session() {
// If we are using SSL we need to free the allocated objects. Here we only
// free the SSL object because the `SSL_free` function also automatically
// frees the BIO object.
if (ssl_) {
SSL_free(ssl_);
}
}
/** /**
* This function is called from the communication stack when an event occurs * This function is called from the communication stack when an event occurs
* indicating that there is data waiting to be read. This function calls the * indicating that there is data waiting to be read. This function calls the
@ -87,29 +153,69 @@ class Session {
// Allocate the buffer to fill the data. // Allocate the buffer to fill the data.
auto buf = input_buffer_.write_end().Allocate(); auto buf = input_buffer_.write_end().Allocate();
// Read from the buffer at most buf.len bytes in a non-blocking fashion.
int len = socket_.Read(buf.data, buf.len, true);
// Check for read errors. if (ssl_) {
if (len == -1) { // We clear errors here to prevent errors piling up in the internal
// This means read would block or read was interrupted by signal, we // OpenSSL error queue. To see when could that be an issue read this:
// return `true` to indicate that all data is processad and to stop // https://www.arangodb.com/2014/07/started-hate-openssl/
// reading of data. ERR_clear_error();
if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) {
return true; // Read data from the socket using the OpenSSL API.
auto len = SSL_read(ssl_, buf.data, buf.len);
// Check for read errors.
if (len < 0) {
auto err = SSL_get_error(ssl_, len);
if (err == SSL_ERROR_WANT_READ) {
// OpenSSL want's to read more data from the socket. We return `true`
// to stop execution of the session to wait for more data to be
// received.
return true;
} else if (err == SSL_ERROR_WANT_WRITE) {
// The OpenSSL library wants to perfrom some kind of handshake so we
// wait for the socket to become ready for a write and call the read
// again. We return `false` so that the listener calls this function
// again.
socket_.WaitForReadyWrite();
return false;
} else {
// This is a fatal error.
throw utils::BasicException(SslGetLastError());
}
} else if (len == 0) {
// The client closed the connection.
throw SessionClosedException("Session was closed by the client.");
return false;
} else {
// Notify the input buffer that it has new data.
input_buffer_.write_end().Written(len);
} }
// Some other error occurred, throw an exception to start session cleanup. } else {
throw utils::BasicException("Couldn't read data from socket!"); // Read from the buffer at most buf.len bytes in a non-blocking fashion.
} // Note, the `true` parameter for non-blocking here is redundant because
// the socket already is non-blocking.
auto len = socket_.Read(buf.data, buf.len, true);
// The client has closed the connection. // Check for read errors.
if (len == 0) { if (len == -1) {
throw SessionClosedException("Session was closed by client."); // This means read would block or read was interrupted by signal, we
// return `true` to indicate that all data is processad and to stop
// reading of data.
if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) {
return true;
}
// Some other error occurred, throw an exception to start session
// cleanup.
throw utils::BasicException("Couldn't read data from the socket!");
} else if (len == 0) {
// The client has closed the connection.
throw SessionClosedException("Session was closed by client.");
} else {
// Notify the input buffer that it has new data.
input_buffer_.write_end().Written(len);
}
} }
// Notify the input buffer that it has new data.
input_buffer_.write_end().Written(len);
// Execute the session. // Execute the session.
session_.Execute(); session_.Execute();
@ -142,6 +248,52 @@ class Session {
last_event_time_ = std::chrono::steady_clock::now(); last_event_time_ = std::chrono::steady_clock::now();
} }
// TODO (mferencevic): the `have_more` flag currently isn't supported
// when using OpenSSL
bool Write(const uint8_t *data, size_t len, bool have_more = false) {
if (ssl_) {
// `SSL_write` has the interface of a normal `write` call. Because of that
// we need to ensure that all data is written to the socket manually.
while (len > 0) {
// We clear errors here to prevent errors piling up in the internal
// OpenSSL error queue. To see when could that be an issue read this:
// https://www.arangodb.com/2014/07/started-hate-openssl/
ERR_clear_error();
// Write data to the socket using OpenSSL.
auto written = SSL_write(ssl_, data, len);
if (written < 0) {
auto err = SSL_get_error(ssl_, written);
if (err == SSL_ERROR_WANT_READ) {
// OpenSSL wants to perform some kind of handshake, we need to
// ensure that there is data available for the next call to
// `SSL_write`.
socket_.WaitForReadyRead();
} else if (err == SSL_ERROR_WANT_WRITE) {
// The socket probably returned WOULDBLOCK and we need to wait for
// the output buffers to clear and reattempt the send.
socket_.WaitForReadyWrite();
} else {
// This is a fatal error.
return false;
}
} else if (written == 0) {
// The client closed the connection.
return false;
} else {
len -= written;
data += written;
}
}
return true;
} else {
// This function guarantees that all data will be written to the socket
// even if the socket is non-blocking. It will use a non-busy wait to send
// all data.
return socket_.Write(data, len, have_more);
}
}
// We own the socket. // We own the socket.
io::network::Socket socket_; io::network::Socket socket_;
@ -157,5 +309,9 @@ class Session {
std::chrono::steady_clock::now()}; std::chrono::steady_clock::now()};
utils::SpinLock lock_; utils::SpinLock lock_;
const int inactivity_timeout_sec_; const int inactivity_timeout_sec_;
};
// SSL objects.
SSL *ssl_{nullptr};
BIO *bio_{nullptr};
}; // namespace communication
} // namespace communication } // namespace communication

View File

@ -11,6 +11,7 @@
#include <netdb.h> #include <netdb.h>
#include <netinet/in.h> #include <netinet/in.h>
#include <netinet/tcp.h> #include <netinet/tcp.h>
#include <poll.h>
#include <sys/epoll.h> #include <sys/epoll.h>
#include <sys/socket.h> #include <sys/socket.h>
#include <sys/types.h> #include <sys/types.h>
@ -219,8 +220,13 @@ bool Socket::Write(const uint8_t *data, size_t len, bool have_more) {
// Terminal error, return failure. // Terminal error, return failure.
return false; return false;
} }
// Non-fatal error, retry. // Non-fatal error, retry after the socket is ready. This is here to
continue; // implement a non-busy wait. If we just continue with the loop we have a
// busy wait.
if (!WaitForReadyWrite()) return false;
} else if (written == 0) {
// The client closed the connection.
return false;
} else { } else {
len -= written; len -= written;
data += written; data += written;
@ -234,7 +240,32 @@ bool Socket::Write(const std::string &s, bool have_more) {
have_more); have_more);
} }
int Socket::Read(void *buffer, size_t len, bool nonblock) { ssize_t Socket::Read(void *buffer, size_t len, bool nonblock) {
return recv(socket_, buffer, len, nonblock ? MSG_DONTWAIT : 0); return recv(socket_, buffer, len, nonblock ? MSG_DONTWAIT : 0);
} }
bool Socket::WaitForReadyRead() {
struct pollfd p;
p.fd = socket_;
p.events = POLLIN;
// We call poll with one element in the poll fds array (first and second
// arguments), also we set the timeout to -1 to block indefinitely until an
// event occurs.
int ret = poll(&p, 1, -1);
if (ret < 1) return false;
return p.revents & POLLIN;
}
bool Socket::WaitForReadyWrite() {
struct pollfd p;
p.fd = socket_;
p.events = POLLOUT;
// We call poll with one element in the poll fds array (first and second
// arguments), also we set the timeout to -1 to block indefinitely until an
// event occurs.
int ret = poll(&p, 1, -1);
if (ret < 1) return false;
return p.revents & POLLOUT;
}
} // namespace io::network } // namespace io::network

View File

@ -153,7 +153,41 @@ class Socket {
* == 0 if the client closed the connection * == 0 if the client closed the connection
* < 0 if an error has occurred * < 0 if an error has occurred
*/ */
int Read(void *buffer, size_t len, bool nonblock = false); ssize_t Read(void *buffer, size_t len, bool nonblock = false);
/**
* Wait until the socket becomes ready for a `Read` operation.
* This function blocks indefinitely waiting for the socket to change its
* state. This function is useful when you need a blocking operation on a
* non-blocking socket, you can call this function to ensure that your next
* `Read` operation will succeed.
*
* The function returns `true` if the wait succeded (there is data waiting to
* be read from the socket) and returns `false` if the wait failed (the socket
* was closed or something else bad happened).
*
* @return wait success status:
* true if the wait succeeded
* false if the wait failed
*/
bool WaitForReadyRead();
/**
* Wait until the socket becomes ready for a `Write` operation.
* This function blocks indefinitely waiting for the socket to change its
* state. This function is useful when you need a blocking operation on a
* non-blocking socket, you can call this function to ensure that your next
* `Write` operation will succeed.
*
* The function returns `true` if the wait succeded (the socket can be written
* to) and returns `false` if the wait failed (the socket was closed or
* something else bad happened).
*
* @return wait success status:
* true if the wait succeeded
* false if the wait failed
*/
bool WaitForReadyWrite();
private: private:
Socket(int fd, const Endpoint &endpoint) : socket_(fd), endpoint_(endpoint) {} Socket(int fd, const Endpoint &endpoint) : socket_(fd), endpoint_(endpoint) {}

View File

@ -28,6 +28,7 @@ using communication::bolt::SessionData;
using SessionT = communication::bolt::Session<communication::InputStream, using SessionT = communication::bolt::Session<communication::InputStream,
communication::OutputStream>; communication::OutputStream>;
using ServerT = communication::Server<SessionT, SessionData>; using ServerT = communication::Server<SessionT, SessionData>;
using communication::ServerContext;
// General purpose flags. // General purpose flags.
DEFINE_string(interface, "0.0.0.0", DEFINE_string(interface, "0.0.0.0",
@ -41,6 +42,8 @@ DEFINE_VALIDATED_int32(session_inactivity_timeout, 1800,
"Time in seconds after which inactive sessions will be " "Time in seconds after which inactive sessions will be "
"closed.", "closed.",
FLAG_IN_RANGE(1, INT32_MAX)); FLAG_IN_RANGE(1, INT32_MAX));
DEFINE_string(cert_file, "", "Certificate file to use.");
DEFINE_string(key_file, "", "Key file to use.");
DEFINE_string(log_file, "", "Path to where the log should be stored."); DEFINE_string(log_file, "", "Path to where the log should be stored.");
DEFINE_HIDDEN_string( DEFINE_HIDDEN_string(
log_link_basename, "", log_link_basename, "",
@ -142,6 +145,9 @@ int WithInit(int argc, char **argv,
stats::InitStatsLogging(get_stats_prefix()); stats::InitStatsLogging(get_stats_prefix());
utils::OnScopeExit stop_stats([] { stats::StopStatsLogging(); }); utils::OnScopeExit stop_stats([] { stats::StopStatsLogging(); });
// Initialize the communication library.
communication::Init();
// Start memory warning logger. // Start memory warning logger.
utils::Scheduler mem_log_scheduler; utils::Scheduler mem_log_scheduler;
if (FLAGS_memory_warning_threshold > 0) { if (FLAGS_memory_warning_threshold > 0) {
@ -160,9 +166,17 @@ void SingleNodeMain() {
google::SetUsageMessage("Memgraph single-node database server"); google::SetUsageMessage("Memgraph single-node database server");
database::SingleNode db; database::SingleNode db;
SessionData session_data{db}; SessionData session_data{db};
ServerContext context;
std::string service_name = "Bolt";
if (FLAGS_key_file != "" && FLAGS_cert_file != "") {
context = ServerContext(FLAGS_key_file, FLAGS_cert_file);
service_name = "BoltS";
}
ServerT server({FLAGS_interface, static_cast<uint16_t>(FLAGS_port)}, ServerT server({FLAGS_interface, static_cast<uint16_t>(FLAGS_port)},
session_data, FLAGS_session_inactivity_timeout, "Bolt", session_data, &context, FLAGS_session_inactivity_timeout,
FLAGS_num_workers); service_name, FLAGS_num_workers);
// Setup telemetry // Setup telemetry
std::experimental::optional<telemetry::Telemetry> telemetry; std::experimental::optional<telemetry::Telemetry> telemetry;
@ -214,9 +228,17 @@ void MasterMain() {
database::Master db; database::Master db;
SessionData session_data{db}; SessionData session_data{db};
ServerContext context;
std::string service_name = "Bolt";
if (FLAGS_key_file != "" && FLAGS_cert_file != "") {
context = ServerContext(FLAGS_key_file, FLAGS_cert_file);
service_name = "BoltS";
}
ServerT server({FLAGS_interface, static_cast<uint16_t>(FLAGS_port)}, ServerT server({FLAGS_interface, static_cast<uint16_t>(FLAGS_port)},
session_data, FLAGS_session_inactivity_timeout, "Bolt", session_data, &context, FLAGS_session_inactivity_timeout,
FLAGS_num_workers); service_name, FLAGS_num_workers);
// Handler for regular termination signals // Handler for regular termination signals
auto shutdown = [&server] { auto shutdown = [&server] {

View File

@ -42,10 +42,11 @@ class TestSession {
input_stream_.Shift(size + 2); input_stream_.Shift(size + 2);
} }
communication::InputStream input_stream_; communication::InputStream &input_stream_;
communication::OutputStream output_stream_; communication::OutputStream &output_stream_;
}; };
using ContextT = communication::ServerContext;
using ServerT = communication::Server<TestSession, TestData>; using ServerT = communication::Server<TestSession, TestData>;
void client_run(int num, const char *interface, uint16_t port, void client_run(int num, const char *interface, uint16_t port,

View File

@ -32,8 +32,8 @@ class TestSession {
output_stream_.Write(input_stream_.data(), input_stream_.size()); output_stream_.Write(input_stream_.data(), input_stream_.size());
} }
communication::InputStream input_stream_; communication::InputStream &input_stream_;
communication::OutputStream output_stream_; communication::OutputStream &output_stream_;
}; };
std::atomic<bool> run{true}; std::atomic<bool> run{true};
@ -63,8 +63,9 @@ TEST(Network, SocketReadHangOnConcurrentConnections) {
TestData data; TestData data;
int N = (std::thread::hardware_concurrency() + 1) / 2; int N = (std::thread::hardware_concurrency() + 1) / 2;
int Nc = N * 3; int Nc = N * 3;
communication::Server<TestSession, TestData> server(endpoint, data, -1, communication::ServerContext context;
"Test", N); communication::Server<TestSession, TestData> server(endpoint, data, &context,
-1, "Test", N);
const auto &ep = server.endpoint(); const auto &ep = server.endpoint();
// start clients // start clients

View File

@ -21,7 +21,8 @@ TEST(Network, Server) {
// initialize server // initialize server
TestData session_data; TestData session_data;
int N = (std::thread::hardware_concurrency() + 1) / 2; int N = (std::thread::hardware_concurrency() + 1) / 2;
ServerT server(endpoint, session_data, -1, "Test", N); ContextT context;
ServerT server(endpoint, session_data, &context, -1, "Test", N);
const auto &ep = server.endpoint(); const auto &ep = server.endpoint();
// start clients // start clients

View File

@ -22,7 +22,8 @@ TEST(Network, SessionLeak) {
// initialize server // initialize server
TestData session_data; TestData session_data;
ServerT server(endpoint, session_data, -1, "Test", 2); ContextT context;
ServerT server(endpoint, session_data, &context, -1, "Test", 2);
// start clients // start clients
int N = 50; int N = 50;

View File

@ -1,2 +1,5 @@
# telemetry test binaries # telemetry test binaries
add_subdirectory(telemetry) add_subdirectory(telemetry)
# ssl test binaries
add_subdirectory(ssl)

View File

@ -6,3 +6,11 @@
- server.py # server script - server.py # server script
- ../../../build_debug/tests/integration/telemetry/client # client binary - ../../../build_debug/tests/integration/telemetry/client # client binary
- ../../../build_debug/tests/manual/kvstore_console # kvstore console binary - ../../../build_debug/tests/manual/kvstore_console # kvstore console binary
- name: integration__ssl
cd: ssl
commands: ./runner.sh
infiles:
- runner.sh # runner script
- ../../../build_debug/tests/integration/ssl/tester # tester binary
enable_network: true

View File

@ -0,0 +1,6 @@
set(target_name memgraph__integration__ssl)
set(tester_target_name ${target_name}__tester)
add_executable(${tester_target_name} tester.cpp)
set_target_properties(${tester_target_name} PROPERTIES OUTPUT_NAME tester)
target_link_libraries(${tester_target_name} memgraph_lib kvstore_dummy_lib)

78
tests/integration/ssl/runner.sh Executable file
View File

@ -0,0 +1,78 @@
#!/bin/bash -e
pushd () { command pushd "$@" > /dev/null; }
popd () { command popd "$@" > /dev/null; }
die () { printf "\033[1;31m~~ Test failed! ~~\033[0m\n\n"; exit 1; }
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
cd "$DIR"
tmpdir=/tmp/memgraph_integration_ssl
if [ -d $tmpdir ]; then
rm -rf $tmpdir
fi
mkdir -p $tmpdir
cd $tmpdir
easyrsa="EasyRSA-3.0.4"
wget -nv http://deps.memgraph.io/$easyrsa.tgz
tar -xf $easyrsa.tgz
mv $easyrsa ca1
cp -r ca1 ca2
for i in ca1 ca2; do
pushd $i
./easyrsa --batch init-pki
./easyrsa --batch build-ca nopass
./easyrsa --batch build-server-full server nopass
./easyrsa --batch build-client-full client nopass
popd
done
binary_dir="$DIR/../../../build"
if [ ! -d $binary_dir ]; then
binary_dir="$DIR/../../../build_debug"
fi
set +e
echo
for i in ca1 ca2; do
for j in none ca1 ca2; do
for k in false true; do
printf "\033[1;36m~~ Server CA: $i; Client CA: $j; Verify peer: $k ~~\033[0m\n"
if [ "$j" == "none" ]; then
client_key=""
client_cert=""
else
client_key=$j/pki/private/client.key
client_cert=$j/pki/issued/client.crt
fi
$binary_dir/tests/integration/ssl/tester \
--server-key-file=$i/pki/private/server.key \
--server-cert-file=$i/pki/issued/server.crt \
--server-ca-file=$i/pki/ca.crt \
--server-verify-peer=$k \
--client-key-file=$client_key \
--client-cert-file=$client_cert
exitcode=$?
if [ "$i" == "$j" ]; then
[ $exitcode -eq 0 ] || die
else
if $k; then
[ $exitcode -ne 0 ] || die
else
[ $exitcode -eq 0 ] || die
fi
fi
printf "\033[1;32m~~ Test ok! ~~\033[0m\n\n"
done
done
done

View File

@ -0,0 +1,76 @@
#include <atomic>
#include <gflags/gflags.h>
#include <glog/logging.h>
#include "communication/client.hpp"
#include "communication/server.hpp"
#include "utils/exceptions.hpp"
DEFINE_string(server_cert_file, "", "Server certificate file to use.");
DEFINE_string(server_key_file, "", "Server key file to use.");
DEFINE_string(server_ca_file, "", "Server CA file to use.");
DEFINE_bool(server_verify_peer, false, "Set to true to verify the peer.");
DEFINE_string(client_cert_file, "", "Client certificate file to use.");
DEFINE_string(client_key_file, "", "Client key file to use.");
const std::string message = "ssl echo test";
struct EchoData {};
class EchoSession {
public:
EchoSession(EchoData &, communication::InputStream &input_stream,
communication::OutputStream &output_stream)
: input_stream_(input_stream), output_stream_(output_stream) {}
void Execute() {
if (input_stream_.size() < message.size()) return;
LOG(INFO) << "Server received message.";
if (!output_stream_.Write(input_stream_.data(), message.size())) {
throw utils::BasicException("Output stream write failed!");
}
LOG(INFO) << "Server sent message.";
input_stream_.Shift(message.size());
}
private:
communication::InputStream &input_stream_;
communication::OutputStream &output_stream_;
};
int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
// Initialize the communication stack.
communication::Init();
// Initialize the server.
EchoData echo_data;
communication::ServerContext server_context(
FLAGS_server_key_file, FLAGS_server_cert_file, FLAGS_server_ca_file,
FLAGS_server_verify_peer);
communication::Server<EchoSession, EchoData> server(
{"127.0.0.1", 0}, echo_data, &server_context, -1, "SSL", 1);
// Initialize the client.
communication::ClientContext client_context(FLAGS_client_key_file,
FLAGS_client_cert_file);
communication::Client client(&client_context);
// Connect to the server.
CHECK(client.Connect(server.endpoint())) << "Couldn't connect to server!";
// Perform echo.
CHECK(client.Write(message)) << "Client couldn't send message!";
LOG(INFO) << "Client sent message.";
CHECK(client.Read(message.size())) << "Client couldn't receive message!";
LOG(INFO) << "Client received message.";
CHECK(std::string(reinterpret_cast<const char *>(client.GetData()),
message.size()) == message)
<< "Received message isn't equal to sent message!";
return 0;
}

View File

@ -10,6 +10,7 @@
#include "io/network/endpoint.hpp" #include "io/network/endpoint.hpp"
using EndpointT = io::network::Endpoint; using EndpointT = io::network::Endpoint;
using ContextT = communication::ClientContext;
using ClientT = communication::bolt::Client; using ClientT = communication::bolt::Client;
using QueryDataT = communication::bolt::QueryData; using QueryDataT = communication::bolt::QueryData;
using communication::bolt::DecodedValue; using communication::bolt::DecodedValue;
@ -18,23 +19,23 @@ class BoltClient {
public: public:
BoltClient(const std::string &address, uint16_t port, BoltClient(const std::string &address, uint16_t port,
const std::string &username, const std::string &password, const std::string &username, const std::string &password,
const std::string & = "") { const std::string & = "", bool use_ssl = false)
: context_(use_ssl), client_(context_) {
EndpointT endpoint(address, port); EndpointT endpoint(address, port);
client_ = std::make_unique<ClientT>();
if (!client_->Connect(endpoint, username, password)) { if (!client_.Connect(endpoint, username, password)) {
LOG(FATAL) << "Could not connect to: " << endpoint; LOG(FATAL) << "Could not connect to: " << endpoint;
} }
} }
QueryDataT Execute(const std::string &query, QueryDataT Execute(const std::string &query,
const std::map<std::string, DecodedValue> &parameters) { const std::map<std::string, DecodedValue> &parameters) {
return client_->Execute(query, parameters); return client_.Execute(query, parameters);
} }
void Close() { client_->Close(); } void Close() { client_.Close(); }
private: private:
std::unique_ptr<ClientT> client_; ContextT context_;
ClientT client_;
}; };

View File

@ -331,11 +331,14 @@ int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true); gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
communication::Init();
stats::InitStatsLogging( stats::InitStatsLogging(
fmt::format("client.long_running.{}.{}", FLAGS_group, FLAGS_scenario)); fmt::format("client.long_running.{}.{}", FLAGS_group, FLAGS_scenario));
Endpoint endpoint(FLAGS_address, FLAGS_port); Endpoint endpoint(FLAGS_address, FLAGS_port);
Client client; ClientContext context(FLAGS_use_ssl);
Client client(&context);
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) { if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) {
LOG(FATAL) << "Couldn't connect to " << endpoint; LOG(FATAL) << "Couldn't connect to " << endpoint;
} }

View File

@ -11,6 +11,7 @@
#include "utils/exceptions.hpp" #include "utils/exceptions.hpp"
#include "utils/timer.hpp" #include "utils/timer.hpp"
using communication::ClientContext;
using communication::bolt::Client; using communication::bolt::Client;
using communication::bolt::DecodedValue; using communication::bolt::DecodedValue;
using io::network::Endpoint; using io::network::Endpoint;

View File

@ -16,6 +16,7 @@ DEFINE_int32(num_workers, 1, "Number of workers");
DEFINE_string(output, "", "Output file"); DEFINE_string(output, "", "Output file");
DEFINE_string(username, "", "Username for the database"); DEFINE_string(username, "", "Username for the database");
DEFINE_string(password, "", "Password for the database"); DEFINE_string(password, "", "Password for the database");
DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server.");
DEFINE_int32(duration, 30, "Number of seconds to execute benchmark"); DEFINE_int32(duration, 30, "Number of seconds to execute benchmark");
DEFINE_string(group, "unknown", "Test group name"); DEFINE_string(group, "unknown", "Test group name");
@ -97,7 +98,8 @@ class TestClient {
std::thread runner_thread_; std::thread runner_thread_;
private: private:
Client client_; communication::ClientContext context_{FLAGS_use_ssl};
Client client_{&context_};
}; };
void RunMultithreadedTest(std::vector<std::unique_ptr<TestClient>> &clients) { void RunMultithreadedTest(std::vector<std::unique_ptr<TestClient>> &clients) {

View File

@ -271,12 +271,15 @@ int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true); gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
communication::Init();
nlohmann::json config; nlohmann::json config;
std::cin >> config; std::cin >> config;
auto independent_nodes_ids = [&] { auto independent_nodes_ids = [&] {
Endpoint endpoint(io::network::ResolveHostname(FLAGS_address), FLAGS_port); Endpoint endpoint(io::network::ResolveHostname(FLAGS_address), FLAGS_port);
Client client; ClientContext context(FLAGS_use_ssl);
Client client(&context);
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) { if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) {
LOG(FATAL) << "Couldn't connect to " << endpoint; LOG(FATAL) << "Couldn't connect to " << endpoint;
} }

View File

@ -19,6 +19,7 @@ DEFINE_string(address, "127.0.0.1", "Server address");
DEFINE_int32(port, 7687, "Server port"); DEFINE_int32(port, 7687, "Server port");
DEFINE_string(username, "", "Username for the database"); DEFINE_string(username, "", "Username for the database");
DEFINE_string(password, "", "Password for the database"); DEFINE_string(password, "", "Password for the database");
DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server.");
using communication::bolt::DecodedValue; using communication::bolt::DecodedValue;
@ -58,7 +59,8 @@ void ExecuteQueries(const std::vector<std::string> &queries,
for (int i = 0; i < FLAGS_num_workers; ++i) { for (int i = 0; i < FLAGS_num_workers; ++i) {
threads.push_back(std::thread([&]() { threads.push_back(std::thread([&]() {
Endpoint endpoint(FLAGS_address, FLAGS_port); Endpoint endpoint(FLAGS_address, FLAGS_port);
Client client; ClientContext context(FLAGS_use_ssl);
Client client(&context);
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) { if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) {
LOG(FATAL) << "Couldn't connect to " << endpoint; LOG(FATAL) << "Couldn't connect to " << endpoint;
} }
@ -100,6 +102,8 @@ int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true); gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
communication::Init();
std::ifstream ifile; std::ifstream ifile;
std::istream *istream{&std::cin}; std::istream *istream{&std::cin};

View File

@ -71,6 +71,12 @@ target_link_libraries(${test_prefix}sl_position_and_count memgraph_lib kvstore_d
add_manual_test(stripped_timing.cpp) add_manual_test(stripped_timing.cpp)
target_link_libraries(${test_prefix}stripped_timing memgraph_lib kvstore_dummy_lib) target_link_libraries(${test_prefix}stripped_timing memgraph_lib kvstore_dummy_lib)
add_manual_test(ssl_client.cpp)
target_link_libraries(${test_prefix}ssl_client memgraph_lib kvstore_dummy_lib)
add_manual_test(ssl_server.cpp)
target_link_libraries(${test_prefix}ssl_server memgraph_lib kvstore_dummy_lib)
add_manual_test(xorshift.cpp) add_manual_test(xorshift.cpp)
target_link_libraries(${test_prefix}xorshift mg-utils) target_link_libraries(${test_prefix}xorshift mg-utils)

View File

@ -10,15 +10,20 @@ DEFINE_string(address, "127.0.0.1", "Server address");
DEFINE_int32(port, 7687, "Server port"); DEFINE_int32(port, 7687, "Server port");
DEFINE_string(username, "", "Username for the database"); DEFINE_string(username, "", "Username for the database");
DEFINE_string(password, "", "Password for the database"); DEFINE_string(password, "", "Password for the database");
DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server.");
int main(int argc, char **argv) { int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true); gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
communication::Init();
// TODO: handle endpoint exception // TODO: handle endpoint exception
io::network::Endpoint endpoint(io::network::ResolveHostname(FLAGS_address), io::network::Endpoint endpoint(io::network::ResolveHostname(FLAGS_address),
FLAGS_port); FLAGS_port);
communication::bolt::Client client;
communication::ClientContext context(FLAGS_use_ssl);
communication::bolt::Client client(&context);
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) return 1; if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) return 1;

View File

@ -0,0 +1,65 @@
#include <gflags/gflags.h>
#include <glog/logging.h>
#include "communication/client.hpp"
#include "io/network/endpoint.hpp"
#include "utils/timer.hpp"
DEFINE_string(address, "127.0.0.1", "Server address");
DEFINE_int32(port, 54321, "Server port");
DEFINE_string(cert_file, "", "Certificate file to use.");
DEFINE_string(key_file, "", "Key file to use.");
bool EchoMessage(communication::Client &client, const std::string &data) {
uint16_t size = data.size();
if (!client.Write(reinterpret_cast<const uint8_t *>(&size), sizeof(size))) {
LOG(WARNING) << "Couldn't send data size!";
return false;
}
if (!client.Write(data)) {
LOG(WARNING) << "Couldn't send data!";
return false;
}
client.ClearData();
if (!client.Read(size)) {
LOG(WARNING) << "Couldn't receive data!";
return false;
}
if (std::string(reinterpret_cast<const char *>(client.GetData()), size) !=
data) {
LOG(WARNING) << "Received data isn't equal to sent data!";
return false;
}
return true;
}
int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
communication::Init();
io::network::Endpoint endpoint(FLAGS_address, FLAGS_port);
communication::ClientContext context(FLAGS_key_file, FLAGS_cert_file);
communication::Client client(&context);
if (!client.Connect(endpoint)) return 1;
bool success = true;
while (true) {
std::string s;
std::getline(std::cin, s);
if (s == "") break;
if (!EchoMessage(client, s)) {
success = false;
break;
}
}
// Send server shutdown signal. The call will fail, we don't care.
EchoMessage(client, "");
return success ? 0 : 1;
}

View File

@ -0,0 +1,73 @@
#include <atomic>
#include <gflags/gflags.h>
#include <glog/logging.h>
#include "communication/server.hpp"
#include "utils/exceptions.hpp"
DEFINE_string(address, "127.0.0.1", "Server address");
DEFINE_int32(port, 54321, "Server port");
DEFINE_string(cert_file, "", "Certificate file to use.");
DEFINE_string(key_file, "", "Key file to use.");
DEFINE_string(ca_file, "", "CA file to use.");
DEFINE_bool(verify_peer, false, "Set to true to verify the peer.");
struct EchoData {
std::atomic<bool> alive{true};
};
class EchoSession {
public:
EchoSession(EchoData &data, communication::InputStream &input_stream,
communication::OutputStream &output_stream)
: data_(data),
input_stream_(input_stream),
output_stream_(output_stream) {}
void Execute() {
if (input_stream_.size() < 2) return;
const uint8_t *data = input_stream_.data();
uint16_t size = *reinterpret_cast<const uint16_t *>(input_stream_.data());
input_stream_.Resize(size + 2);
if (input_stream_.size() < size + 2) return;
if (size == 0) {
LOG(INFO) << "Server received EOF message";
data_.alive.store(false);
return;
}
LOG(INFO) << "Server received '"
<< std::string(reinterpret_cast<const char *>(data + 2), size)
<< "'";
if (!output_stream_.Write(data + 2, size)) {
throw utils::BasicException("Output stream write failed!");
}
input_stream_.Shift(size + 2);
}
private:
EchoData &data_;
communication::InputStream &input_stream_;
communication::OutputStream &output_stream_;
};
int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
communication::Init();
// Initialize the server.
EchoData echo_data;
io::network::Endpoint endpoint(FLAGS_address, FLAGS_port);
communication::ServerContext context(FLAGS_key_file, FLAGS_cert_file,
FLAGS_ca_file, FLAGS_verify_peer);
communication::Server<EchoSession, EchoData> server(endpoint, echo_data,
&context, -1, "SSL", 1);
while (echo_data.alive) {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
return 0;
}

View File

@ -1 +1,2 @@
.long_running_stats .long_running_stats
*.pem

View File

@ -10,6 +10,10 @@
commands: TIMEOUT=600 ./continuous_integration --properties-on-disk commands: TIMEOUT=600 ./continuous_integration --properties-on-disk
infiles: *STRESS_INFILES infiles: *STRESS_INFILES
- name: stress_ssl
commands: TIMEOUT=600 ./continuous_integration --use-ssl
infiles: *STRESS_INFILES
- name: stress_large - name: stress_large
project: release project: release
commands: TIMEOUT=43200 ./continuous_integration --large-dataset commands: TIMEOUT=43200 ./continuous_integration --large-dataset

View File

@ -146,14 +146,14 @@ def connection_argument_parser():
''' '''
parser = ArgumentParser(description=__doc__) parser = ArgumentParser(description=__doc__)
parser.add_argument('--endpoint', type=str, default='localhost:7687', parser.add_argument('--endpoint', type=str, default='127.0.0.1:7687',
help='DBMS instance endpoint. ' help='DBMS instance endpoint. '
'Bolt protocol is the only option.') 'Bolt protocol is the only option.')
parser.add_argument('--username', type=str, default='neo4j', parser.add_argument('--username', type=str, default='neo4j',
help='DBMS instance username.') help='DBMS instance username.')
parser.add_argument('--password', type=int, default='1234', parser.add_argument('--password', type=int, default='1234',
help='DBMS instance password.') help='DBMS instance password.')
parser.add_argument('--ssl-enabled', action='store_true', parser.add_argument('--use-ssl', action='store_true',
help="Is SSL enabled?") help="Is SSL enabled?")
return parser return parser
@ -163,7 +163,7 @@ def bolt_session(url, auth, ssl=False):
''' '''
with wrapper around Bolt session. with wrapper around Bolt session.
:param url: str, e.g. "bolt://localhost:7687" :param url: str, e.g. "bolt://127.0.0.1:7687"
:param auth: auth method, goes directly to the Bolt driver constructor :param auth: auth method, goes directly to the Bolt driver constructor
:param ssl: bool, is ssl enabled :param ssl: bool, is ssl enabled
''' '''
@ -183,14 +183,15 @@ def argument_session(args):
:return: Bolt session context manager based on program arguments :return: Bolt session context manager based on program arguments
''' '''
return bolt_session('bolt://' + args.endpoint, return bolt_session('bolt://' + args.endpoint,
(args.username, str(args.password))) (args.username, str(args.password)),
args.use_ssl)
def argument_driver(args, ssl=False): def argument_driver(args):
return GraphDatabase.driver( return GraphDatabase.driver(
'bolt://' + args.endpoint, 'bolt://' + args.endpoint,
auth=(args.username, str(args.password)), auth=(args.username, str(args.password)),
encrypted=ssl) encrypted=args.use_ssl)
# This class is used to create and cache sessions. Session is cached by args # This class is used to create and cache sessions. Session is cached by args
# used to create it and process' pid in which it was created. This makes it easy # used to create it and process' pid in which it was created. This makes it easy

View File

@ -73,6 +73,8 @@ BASE_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", ".."))
BUILD_DIR = os.path.join(BASE_DIR, "build") BUILD_DIR = os.path.join(BASE_DIR, "build")
CONFIG_DIR = os.path.join(BASE_DIR, "config") CONFIG_DIR = os.path.join(BASE_DIR, "config")
MEASUREMENTS_FILE = os.path.join(SCRIPT_DIR, ".apollo_measurements") MEASUREMENTS_FILE = os.path.join(SCRIPT_DIR, ".apollo_measurements")
KEY_FILE = os.path.join(SCRIPT_DIR, ".key.pem")
CERT_FILE = os.path.join(SCRIPT_DIR, ".cert.pem")
# long running stats file # long running stats file
STATS_FILE = os.path.join(SCRIPT_DIR, ".long_running_stats") STATS_FILE = os.path.join(SCRIPT_DIR, ".long_running_stats")
@ -132,6 +134,8 @@ parser.add_argument("--python", default = os.path.join(SCRIPT_DIR,
"ve3", "bin", "python3"), type = str) "ve3", "bin", "python3"), type = str)
parser.add_argument("--large-dataset", action = "store_const", parser.add_argument("--large-dataset", action = "store_const",
const = True, default = False) const = True, default = False)
parser.add_argument("--use-ssl", action = "store_const",
const = True, default = False)
parser.add_argument("--verbose", action = "store_const", parser.add_argument("--verbose", action = "store_const",
const = True, default = False) const = True, default = False)
args = parser.parse_args() args = parser.parse_args()
@ -140,6 +144,14 @@ args = parser.parse_args()
if not os.path.exists(args.memgraph): if not os.path.exists(args.memgraph):
args.memgraph = os.path.join(BASE_DIR, "build_release", "memgraph") args.memgraph = os.path.join(BASE_DIR, "build_release", "memgraph")
# generate temporary SSL certs
if args.use_ssl:
# https://unix.stackexchange.com/questions/104171/create-ssl-certificate-non-interactively
subj = "/C=HR/ST=Zagreb/L=Zagreb/O=Memgraph/CN=db.memgraph.com"
subprocess.run(["openssl", "req", "-new", "-newkey", "rsa:4096",
"-days", "365", "-nodes", "-x509", "-subj", subj,
"-keyout", KEY_FILE, "-out", CERT_FILE], check=True)
# start memgraph # start memgraph
cwd = os.path.dirname(args.memgraph) cwd = os.path.dirname(args.memgraph)
cmd = [args.memgraph, "--num-workers=" + str(THREADS)] cmd = [args.memgraph, "--num-workers=" + str(THREADS)]
@ -151,6 +163,8 @@ if args.durability_directory:
cmd += ["--durability-directory", args.durability_directory] cmd += ["--durability-directory", args.durability_directory]
if args.properties_on_disk: if args.properties_on_disk:
cmd += ["--properties-on-disk", "id,x"] cmd += ["--properties-on-disk", "id,x"]
if args.use_ssl:
cmd += ["--cert-file", CERT_FILE, "--key-file", KEY_FILE]
proc_mg = subprocess.Popen(cmd, cwd = cwd, proc_mg = subprocess.Popen(cmd, cwd = cwd,
env = {"MEMGRAPH_CONFIG": args.config}) env = {"MEMGRAPH_CONFIG": args.config})
time.sleep(1.0) time.sleep(1.0)
@ -167,6 +181,8 @@ def cleanup():
runtimes = {} runtimes = {}
dataset = LARGE_DATASET if args.large_dataset else SMALL_DATASET dataset = LARGE_DATASET if args.large_dataset else SMALL_DATASET
for test in dataset: for test in dataset:
if args.use_ssl:
test["options"] += ["--use-ssl"]
runtime = run_test(args, **test) runtime = run_test(args, **test)
runtimes[os.path.splitext(test["test"])[0]] = runtime runtimes[os.path.splitext(test["test"])[0]] = runtime
@ -176,6 +192,11 @@ ret_mg = proc_mg.wait()
if ret_mg != 0: if ret_mg != 0:
raise Exception("Memgraph binary returned non-zero ({})!".format(ret_mg)) raise Exception("Memgraph binary returned non-zero ({})!".format(ret_mg))
# cleanup certificates
if args.use_ssl:
os.remove(KEY_FILE)
os.remove(CERT_FILE)
# measurements # measurements
measurements = "" measurements = ""
for key, value in runtimes.items(): for key, value in runtimes.items():

View File

@ -8,6 +8,7 @@
#include "utils/timer.hpp" #include "utils/timer.hpp"
using EndpointT = io::network::Endpoint; using EndpointT = io::network::Endpoint;
using ClientContextT = communication::ClientContext;
using ClientT = communication::bolt::Client; using ClientT = communication::bolt::Client;
using DecodedValueT = communication::bolt::DecodedValue; using DecodedValueT = communication::bolt::DecodedValue;
using QueryDataT = communication::bolt::QueryData; using QueryDataT = communication::bolt::QueryData;
@ -17,6 +18,7 @@ DEFINE_string(address, "127.0.0.1", "Server address");
DEFINE_int32(port, 7687, "Server port"); DEFINE_int32(port, 7687, "Server port");
DEFINE_string(username, "", "Username for the database"); DEFINE_string(username, "", "Username for the database");
DEFINE_string(password, "", "Password for the database"); DEFINE_string(password, "", "Password for the database");
DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server.");
DEFINE_int32(vertex_count, 0, DEFINE_int32(vertex_count, 0,
"The average number of vertices in the graph per worker"); "The average number of vertices in the graph per worker");
@ -51,7 +53,7 @@ class GraphSession {
} }
EndpointT endpoint(FLAGS_address, FLAGS_port); EndpointT endpoint(FLAGS_address, FLAGS_port);
client_ = std::make_unique<ClientT>(); client_ = std::make_unique<ClientT>(&context_);
if (!client_->Connect(endpoint, FLAGS_username, FLAGS_password)) { if (!client_->Connect(endpoint, FLAGS_username, FLAGS_password)) {
throw utils::BasicException("Couldn't connect to server!"); throw utils::BasicException("Couldn't connect to server!");
@ -60,6 +62,7 @@ class GraphSession {
private: private:
uint64_t id_; uint64_t id_;
ClientContextT context_{FLAGS_use_ssl};
std::unique_ptr<ClientT> client_; std::unique_ptr<ClientT> client_;
std::set<uint64_t> vertices_; std::set<uint64_t> vertices_;
@ -362,6 +365,8 @@ int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true); gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
communication::Init();
CHECK(FLAGS_vertex_count > 0) << "Vertex count must be greater than 0!"; CHECK(FLAGS_vertex_count > 0) << "Vertex count must be greater than 0!";
CHECK(FLAGS_edge_count > 0) << "Edge count must be greater than 0!"; CHECK(FLAGS_edge_count > 0) << "Edge count must be greater than 0!";
@ -369,7 +374,8 @@ int main(int argc, char **argv) {
// create client // create client
EndpointT endpoint(FLAGS_address, FLAGS_port); EndpointT endpoint(FLAGS_address, FLAGS_port);
ClientT client; ClientContextT context(FLAGS_use_ssl);
ClientT client(&context);
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) { if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) {
throw utils::BasicException("Couldn't connect to server!"); throw utils::BasicException("Couldn't connect to server!");
} }

View File

@ -52,8 +52,9 @@ bool QueryServer(io::network::Socket &socket) {
TEST(NetworkTimeouts, InactiveSession) { TEST(NetworkTimeouts, InactiveSession) {
// Instantiate the server and set the session timeout to 2 seconds. // Instantiate the server and set the session timeout to 2 seconds.
TestData test_data; TestData test_data;
communication::ServerContext context;
communication::Server<TestSession, TestData> server{ communication::Server<TestSession, TestData> server{
{"127.0.0.1", 0}, test_data, 2, "Test", 1}; {"127.0.0.1", 0}, test_data, &context, 2, "Test", 1};
// Create the client and connect to the server. // Create the client and connect to the server.
io::network::Socket client; io::network::Socket client;

94
tests/unit/socket.cpp Normal file
View File

@ -0,0 +1,94 @@
#include <chrono>
#include <csignal>
#include <thread>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "io/network/socket.hpp"
#include "utils/timer.hpp"
TEST(Socket, WaitForReadyRead) {
io::network::Socket server;
ASSERT_TRUE(server.Bind({"127.0.0.1", 0}));
ASSERT_TRUE(server.Listen(1024));
std::thread thread([&server] {
io::network::Socket client;
ASSERT_TRUE(client.Connect(server.endpoint()));
std::this_thread::sleep_for(std::chrono::milliseconds(2000));
ASSERT_TRUE(client.Write("test"));
});
uint8_t buff[100];
auto client = server.Accept();
ASSERT_TRUE(client);
client->SetNonBlocking();
ASSERT_EQ(client->Read(buff, sizeof(buff)), -1);
ASSERT_TRUE(errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR);
utils::Timer timer;
ASSERT_TRUE(client->WaitForReadyRead());
ASSERT_GT(timer.Elapsed().count(), 1.0);
ASSERT_GT(client->Read(buff, sizeof(buff)), 0);
thread.join();
}
TEST(Socket, WaitForReadyWrite) {
io::network::Socket server;
ASSERT_TRUE(server.Bind({"127.0.0.1", 0}));
ASSERT_TRUE(server.Listen(1024));
std::thread thread([&server] {
uint8_t buff[10000];
io::network::Socket client;
ASSERT_TRUE(client.Connect(server.endpoint()));
client.SetNonBlocking();
// Decrease the TCP read buffer.
int len = 1024;
ASSERT_EQ(setsockopt(client.fd(), SOL_SOCKET, SO_RCVBUF, &len, sizeof(len)),
0);
std::this_thread::sleep_for(std::chrono::milliseconds(2000));
while (true) {
int ret = client.Read(buff, sizeof(buff));
if (ret == -1 && errno != EAGAIN && errno != EWOULDBLOCK &&
errno != EINTR) {
std::raise(SIGPIPE);
} else if (ret == 0) {
break;
}
}
});
auto client = server.Accept();
ASSERT_TRUE(client);
client->SetNonBlocking();
// Decrease the TCP write buffer.
int len = 1024;
ASSERT_EQ(setsockopt(client->fd(), SOL_SOCKET, SO_SNDBUF, &len, sizeof(len)),
0);
utils::Timer timer;
for (int i = 0; i < 1000000; ++i) {
ASSERT_TRUE(client->Write("test"));
}
ASSERT_GT(timer.Elapsed().count(), 1.0);
client->Close();
thread.join();
}
int main(int argc, char **argv) {
google::InitGoogleLogging(argv[0]);
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}