Implement SSL support for servers and clients

Summary:
This diff implements OpenSSL support in the network stack.
Currently SSL support is only enabled for Bolt connections,
support for RPC connections will be added in another diff.

Reviewers: buda, teon.banek

Reviewed By: buda

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D1328
This commit is contained in:
Matej Ferencevic 2018-06-20 17:44:47 +02:00
parent 44821a918c
commit 1d448d40ca
55 changed files with 1400 additions and 177 deletions

View File

@ -10,6 +10,7 @@
* Static vertices/edges id generators exposed through the Id Cypher function.
* Properties on disk added.
* Telemetry added.
* SSL support added.
* Add `toString` function to openCypher
### Bug Fixes and Other Changes

View File

@ -140,6 +140,9 @@ endif()
set(Boost_USE_STATIC_LIBS ON)
find_package(Boost 1.62 REQUIRED COMPONENTS iostreams serialization)
# OpenSSL
find_package(OpenSSL REQUIRED)
set(libs_dir ${CMAKE_SOURCE_DIR}/libs)
add_subdirectory(libs EXCLUDE_FROM_ALL)
@ -320,6 +323,8 @@ set(CPACK_DEBIAN_PACKAGE_DESCRIPTION "${CPACK_PACKAGE_DESCRIPTION_SUMMARY}
Contains Memgraph, the graph database. It aims to deliver developers the
speed, simplicity and scale required to build the next generation of
applications driver by real-time connected data.")
# Add `openssl` package to dependencies list. Used to generate SSL certificates.
set(CPACK_DEBIAN_memgraph_PACKAGE_DEPENDS "openssl (>= 1.1.0)")
# RPM specific
set(CPACK_RPM_PACKAGE_URL https://memgraph.com)
@ -335,6 +340,8 @@ set(CPACK_RPM_USER_BINARY_SPECFILE "${CMAKE_SOURCE_DIR}/release/rpm/memgraph.spe
set(CPACK_RPM_PACKAGE_DESCRIPTION "Contains Memgraph, the graph database.
It aims to deliver developers the speed, simplicity and scale required to build
the next generation of applications driver by real-time connected data.")
# Add `openssl` package to dependencies list. Used to generate SSL certificates.
set(CPACK_RPM_memgraph_PACKAGE_REQUIRES "openssl >= 1.0.0")
# All variables must be set before including.
include(CPack)

View File

@ -16,6 +16,12 @@
# Port the server should listen on.
--port=7687
# Path to a SSL certificate file that should be used.
--cert-file=/etc/memgraph/ssl/cert.pem
# Path to a SSL key file that should be used.
--key-file=/etc/memgraph/ssl/key.pem
# Number of workers used by the Memgraph server. By default, this will be the
# number of processing units available on the machine.
# --num-workers=8

View File

@ -13,11 +13,9 @@ from neo4j.v1 import GraphDatabase, basic_auth
# Initialize and configure the driver.
# * provide the correct URL where Memgraph is reachable;
# * use an empty user name and password, and
# * disable encryption (not supported).
# * use an empty user name and password.
driver = GraphDatabase.driver("bolt://localhost:7687",
auth=basic_auth("", ""),
encrypted=False)
auth=basic_auth("", ""))
# Start a session in which queries are executed.
session = driver.session()
@ -51,9 +49,7 @@ The details about Java driver can be found
[on GitHub](https://github.com/neo4j/neo4j-java-driver).
The example below is equivalent to Python example. Major difference is that
`Config` object has to be created before the driver construction. Encryption
has to be disabled by calling `withoutEncryption` method against the `Config`
builder.
`Config` object has to be created before the driver construction.
```java
import org.neo4j.driver.v1.*;
@ -64,7 +60,7 @@ import java.util.*;
public class JavaQuickStart {
public static void main(String[] args) {
// Initialize driver.
Config config = Config.build().withoutEncryption().toConfig();
Config config = Config.build().toConfig();
Driver driver = GraphDatabase.driver("bolt://localhost:7687",
AuthTokens.basic("",""),
config);
@ -93,9 +89,7 @@ public class JavaQuickStart {
The details about Javascript driver can be found
[on GitHub](https://github.com/neo4j/neo4j-javascript-driver).
The Javascript example below is equivalent to Python and Java examples. SSL
can be disabled by passing `{encrypted: 'ENCRYPTION_OFF'}` during the driver
construction.
The Javascript example below is equivalent to Python and Java examples.
Here is an example related to `Node.js`. Memgraph doesn't have integrated
support for `WebSocket` which is required during the execution in any web
@ -109,8 +103,7 @@ proxy port.
```javascript
var neo4j = require('neo4j-driver').v1;
var driver = neo4j.driver("bolt://localhost:7687",
neo4j.auth.basic("neo4j", "1234"),
{encrypted: 'ENCRYPTION_OFF'});
neo4j.auth.basic("neo4j", "1234"));
var session = driver.session();
function die() {
@ -146,8 +139,7 @@ run_query("MATCH (n) DETACH DELETE n", function (result) {
The C# driver is hosted
[on GitHub](https://github.com/neo4j/neo4j-dotnet-driver). The example below
performs the same work as all of the previous examples. Encryption is disabled
by setting `EncryptionLevel.NONE` on the `Config`.
performs the same work as all of the previous examples.
```csh
using System;
@ -158,7 +150,6 @@ public class Basic {
public static void Main(string[] args) {
// Initialize the driver.
var config = Config.DefaultConfig;
config.EncryptionLevel = EncryptionLevel.None;
using(var driver = GraphDatabase.Driver("bolt://localhost:7687", AuthTokens.None, config))
using(var session = driver.Session())
{
@ -176,6 +167,18 @@ public class Basic {
}
```
### Secure Sockets Layer (SSL)
Secure connections are supported and enabled by default. The server initially
ships with a self-signed testing certificate. The certificate can be replaced
by editing the following parameters in `/etc/memgraph/memgraph.conf`:
```
--cert-file=/path/to/ssl/certificate.pem
--key-file=/path/to/ssl/privatekey.pem
```
To disable SSL support and use insecure connections to the database you should
set both parameters (`--cert-file` and `--key-file`) to empty values.
### Limitations
Memgraph is currently in early stage, and has a number of limitations we plan
@ -186,9 +189,3 @@ to remove in future versions.
Memgraph is currently single-user only. There is no way to control user
privileges. The default user has read and write privileges over the whole
database.
#### Secure Sockets Layer (SSL)
Secure connections are not supported. For this reason each client
driver needs to be configured not to use encryption. Consult driver-specific
guides for details.

View File

@ -183,7 +183,7 @@ After installing `neo4j-client`, connect to the running Memgraph instance by
issuing the following shell command.
```bash
neo4j-client --insecure -u "" -p "" localhost 7687
neo4j-client -u "" -p "" localhost 7687
```
After the client has started it should present a command prompt similar to:
@ -191,7 +191,7 @@ After the client has started it should present a command prompt similar to:
```bash
neo4j-client 2.1.3
Enter `:help` for usage hints.
Connected to 'neo4j://@localhost:7687' (insecure)
Connected to 'neo4j://@localhost:7687'
neo4j>
```

1
init
View File

@ -8,6 +8,7 @@ required_pkgs=(git arcanist # source code control
curl wget # for downloading libs
uuid-dev default-jre-headless # required by antlr
libreadline-dev # for memgraph console
libssl-dev
libboost-iostreams-dev
libboost-serialization-dev
python3 python-virtualenv python3-pip # for qa, macro_benchmark and stress tests

View File

@ -29,6 +29,16 @@ case "$1" in
chmod 750 /var/log/memgraph || exit 1
# Make examples directory immutable (optional)
chattr +i -R /usr/share/memgraph/examples || true
# Generate SSL certificates
if [ ! -d /etc/memgraph/ssl ]; then
mkdir /etc/memgraph/ssl || exit 1
openssl req -x509 -newkey rsa:4096 -days 3650 -nodes \
-keyout /etc/memgraph/ssl/key.pem -out /etc/memgraph/ssl/cert.pem \
-subj "/C=GB/ST=London/L=London/O=Memgraph Ltd./CN=Memgraph DB" || exit 1
chown memgraph:memgraph /etc/memgraph/ssl/* || exit 1
chmod 400 /etc/memgraph/ssl/* || exit 1
fi
;;
abort-upgrade|abort-remove|abort-deconfigure)

View File

@ -71,6 +71,16 @@ chown memgraph:adm /var/log/memgraph || exit 1
chmod 750 /var/log/memgraph || exit 1
# Make examples directory immutable (optional)
chattr +i -R /usr/share/memgraph/examples || true
# Generate SSL certificates
if [ ! -d /etc/memgraph/ssl ]; then
mkdir /etc/memgraph/ssl || exit 1
openssl req -x509 -newkey rsa:4096 -days 3650 -nodes \
-keyout /etc/memgraph/ssl/key.pem -out /etc/memgraph/ssl/cert.pem \
-subj "/C=GB/ST=London/L=London/O=Memgraph Ltd./CN=Memgraph DB" || exit 1
chown memgraph:memgraph /etc/memgraph/ssl/* || exit 1
chmod 400 /etc/memgraph/ssl/* || exit 1
fi
@RPM_SYMLINK_POSTINSTALL@
@CPACK_RPM_SPEC_POSTINSTALL@

View File

@ -8,6 +8,10 @@ add_subdirectory(telemetry)
# all memgraph src files
set(memgraph_src_files
communication/buffer.cpp
communication/client.cpp
communication/context.cpp
communication/helpers.cpp
communication/init.cpp
communication/bolt/v1/decoder/decoded_value.cpp
communication/rpc/client.cpp
communication/rpc/protocol.cpp
@ -189,6 +193,7 @@ string(TOLOWER ${CMAKE_BUILD_TYPE} lower_build_type)
# memgraph_lib depend on these libraries
set(MEMGRAPH_ALL_LIBS stdc++fs Threads::Threads fmt cppitertools
antlr_opencypher_parser_lib dl glog gflags capnp kj
${OPENSSL_LIBRARIES}
${Boost_IOSTREAMS_LIBRARY_RELEASE}
${Boost_SERIALIZATION_LIBRARY_RELEASE}
mg-utils mg-io)
@ -206,6 +211,7 @@ endif()
# STATIC library used by memgraph executables
add_library(memgraph_lib STATIC ${memgraph_src_files})
target_link_libraries(memgraph_lib ${MEMGRAPH_ALL_LIBS})
target_include_directories(memgraph_lib PRIVATE ${OPENSSL_INCLUDE_DIR})
add_dependencies(memgraph_lib generate_opencypher_parser)
add_dependencies(memgraph_lib generate_lcp)
add_dependencies(memgraph_lib generate_capnp)

View File

@ -32,9 +32,9 @@ struct QueryData {
std::map<std::string, DecodedValue> metadata;
};
class Client {
class Client final {
public:
Client() {}
explicit Client(communication::ClientContext *context) : client_(context) {}
Client(const Client &) = delete;
Client(Client &&) = delete;

View File

@ -20,7 +20,7 @@ namespace communication {
* stack where all execution when it is being done is being done on a single
* thread.
*/
class Buffer {
class Buffer final {
private:
// Initial capacity of the internal buffer.
const size_t kBufferInitialSize = 65536;
@ -28,6 +28,11 @@ class Buffer {
public:
Buffer();
Buffer(const Buffer &) = delete;
Buffer(Buffer &&) = delete;
Buffer &operator=(const Buffer &) = delete;
Buffer &operator=(Buffer &&) = delete;
/**
* This class provides all functions from the buffer that are needed to allow
* reading data from the buffer.
@ -36,6 +41,11 @@ class Buffer {
public:
ReadEnd(Buffer &buffer);
ReadEnd(const ReadEnd &) = delete;
ReadEnd(ReadEnd &&) = delete;
ReadEnd &operator=(const ReadEnd &) = delete;
ReadEnd &operator=(ReadEnd &&) = delete;
uint8_t *data();
size_t size() const;
@ -58,6 +68,11 @@ class Buffer {
public:
WriteEnd(Buffer &buffer);
WriteEnd(const WriteEnd &) = delete;
WriteEnd(WriteEnd &&) = delete;
WriteEnd &operator=(const WriteEnd &) = delete;
WriteEnd &operator=(WriteEnd &&) = delete;
io::network::StreamBuffer Allocate();
void Written(size_t len);

View File

@ -0,0 +1,232 @@
#include <glog/logging.h>
#include "communication/client.hpp"
#include "communication/helpers.hpp"
namespace communication {
Client::Client(ClientContext *context) : context_(context) {}
Client::~Client() {
Close();
ReleaseSslObjects();
}
bool Client::Connect(const io::network::Endpoint &endpoint) {
// Try to establish a socket connection.
if (!socket_.Connect(endpoint)) return false;
// Enable TCP keep alive for all connections.
// Because we manually always set the `have_more` flag to the socket
// `Write` call we can disable the Nagle algorithm because we know that we
// are always sending optimal packets. Even if we don't send optimal
// packets, there will be no delay between packets and throughput won't
// suffer.
socket_.SetKeepAlive();
socket_.SetNoDelay();
if (context_->use_ssl()) {
// Release leftover SSL objects.
ReleaseSslObjects();
// Create a new SSL object that will be used for SSL communication.
ssl_ = SSL_new(context_->context());
if (ssl_ == nullptr) {
LOG(WARNING) << "Couldn't create client SSL object!";
socket_.Close();
return false;
}
// Create a new BIO (block I/O) SSL object so that OpenSSL can communicate
// using our socket. We specify `BIO_NOCLOSE` to indicate to OpenSSL that
// it doesn't need to close the socket when destructing all objects (we
// handle that in our socket destructor).
bio_ = BIO_new_socket(socket_.fd(), BIO_NOCLOSE);
if (bio_ == nullptr) {
LOG(WARNING) << "Couldn't create client BIO object!";
socket_.Close();
return false;
}
// Connect the BIO object to the SSL object so that OpenSSL knows which
// stream it should use for communication. We use the same object for both
// the read and write end. This function cannot fail.
SSL_set_bio(ssl_, bio_, bio_);
// Clear all leftover errors.
ERR_clear_error();
// Perform the TLS handshake.
auto ret = SSL_connect(ssl_);
if (ret != 1) {
LOG(WARNING) << "Couldn't connect to SSL server: " << SslGetLastError();
socket_.Close();
return false;
}
}
return true;
}
bool Client::ErrorStatus() { return socket_.ErrorStatus(); }
void Client::Shutdown() { socket_.Shutdown(); }
void Client::Close() {
if (ssl_) {
// Perform an unidirectional SSL shutdown. That just means that we send
// the shutdown message and don't wait or care for the result.
SSL_shutdown(ssl_);
}
socket_.Close();
}
bool Client::Read(size_t len) {
size_t received = 0;
buffer_.write_end().Resize(buffer_.read_end().size() + len);
while (received < len) {
auto buff = buffer_.write_end().Allocate();
if (ssl_) {
// We clear errors here to prevent errors piling up in the internal
// OpenSSL error queue. To see when could that be an issue read this:
// https://www.arangodb.com/2014/07/started-hate-openssl/
ERR_clear_error();
// Read encrypted data from the socket using OpenSSL.
auto got = SSL_read(ssl_, buff.data, len - received);
// Handle errors that might have occurred.
if (got < 0) {
auto err = SSL_get_error(ssl_, got);
if (err == SSL_ERROR_WANT_READ) {
// OpenSSL want's to read more data from the socket. We wait for
// more data to be ready and retry the call.
socket_.WaitForReadyRead();
continue;
} else if (err == SSL_ERROR_WANT_WRITE) {
// The OpenSSL library probably wants to perform some kind of
// handshake so we wait for the socket to become ready for a write
// and call the read again.
socket_.WaitForReadyWrite();
continue;
} else {
// This is a fatal error.
LOG(WARNING) << "Received an unexpected SSL error: " << err;
return false;
}
} else if (got == 0) {
// The server closed the connection.
return false;
}
// Notify the buffer that it has new data.
buffer_.write_end().Written(got);
received += got;
} else {
// Read raw data from the socket.
auto got = socket_.Read(buff.data, len - received);
if (got <= 0) {
// If `read` returns 0 the server has closed the connection. If `read`
// returns -1 all of the errors that could be found in `errno` are
// fatal errors (because we are using a blocking socket) so return a
// read failure.
return false;
}
// Notify the buffer that it has new data.
buffer_.write_end().Written(got);
received += got;
}
}
return true;
}
uint8_t *Client::GetData() { return buffer_.read_end().data(); }
size_t Client::GetDataSize() { return buffer_.read_end().size(); }
void Client::ShiftData(size_t len) { buffer_.read_end().Shift(len); }
void Client::ClearData() { buffer_.read_end().Clear(); }
bool Client::Write(const uint8_t *data, size_t len, bool have_more) {
if (ssl_) {
// `SSL_write` has the interface of a normal `write` call. Because of that
// we need to ensure that all data is written to the socket manually.
while (len > 0) {
// We clear errors here to prevent errors piling up in the internal
// OpenSSL error queue. To see when could that be an issue read this:
// https://www.arangodb.com/2014/07/started-hate-openssl/
ERR_clear_error();
// Write data to the socket using OpenSSL.
auto written = SSL_write(ssl_, data, len);
if (written < 0) {
auto err = SSL_get_error(ssl_, written);
if (err == SSL_ERROR_WANT_READ) {
// OpenSSL wants to perform some kind of handshake, we need to
// ensure that there is data available for the next call to
// `SSL_write`.
socket_.WaitForReadyRead();
} else if (err == SSL_ERROR_WANT_WRITE) {
// The socket probably returned WOULDBLOCK and we need to wait for
// the output buffers to clear and reattempt the send.
socket_.WaitForReadyWrite();
} else {
// This is a fatal error.
return false;
}
} else if (written == 0) {
// The client closed the connection.
return false;
} else {
len -= written;
data += written;
}
}
return true;
} else {
return socket_.Write(data, len, have_more);
}
}
bool Client::Write(const std::string &str, bool have_more) {
return Write(reinterpret_cast<const uint8_t *>(str.data()), str.size(),
have_more);
}
const io::network::Endpoint &Client::endpoint() { return socket_.endpoint(); }
void Client::ReleaseSslObjects() {
// If we are using SSL we need to free the allocated objects. Here we only
// free the SSL object because the `SSL_free` function also automatically
// frees the BIO object.
if (ssl_) {
SSL_free(ssl_);
ssl_ = nullptr;
bio_ = nullptr;
}
}
ClientInputStream::ClientInputStream(Client &client) : client_(client) {}
uint8_t *ClientInputStream::data() { return client_.GetData(); }
size_t ClientInputStream::size() const { return client_.GetDataSize(); }
void ClientInputStream::Shift(size_t len) { client_.ShiftData(len); }
void ClientInputStream::Clear() { client_.ClearData(); }
ClientOutputStream::ClientOutputStream(Client &client) : client_(client) {}
bool ClientOutputStream::Write(const uint8_t *data, size_t len,
bool have_more) {
return client_.Write(data, len, have_more);
}
bool ClientOutputStream::Write(const std::string &str, bool have_more) {
return client_.Write(str, have_more);
}
} // namespace communication

View File

@ -1,6 +1,12 @@
#pragma once
#include <openssl/bio.h>
#include <openssl/err.h>
#include <openssl/ssl.h>
#include "communication/buffer.hpp"
#include "communication/context.hpp"
#include "communication/init.hpp"
#include "io/network/endpoint.hpp"
#include "io/network/socket.hpp"
@ -10,106 +16,115 @@ namespace communication {
* This class implements a generic network Client.
* It uses blocking sockets and provides an API that can be used to receive/send
* data over the network connection.
*
* NOTE: If you use this client you **must** call `communication::Init()` from
* the `main` function before using the client!
*/
class Client {
class Client final {
public:
explicit Client(ClientContext *context);
~Client();
Client(const Client &) = delete;
Client(Client &&) = delete;
Client &operator=(const Client &) = delete;
Client &operator=(Client &&) = delete;
/**
* This function connects to a remote server and returns whether the connect
* succeeded.
*/
bool Connect(const io::network::Endpoint &endpoint) {
if (!socket_.Connect(endpoint)) return false;
socket_.SetKeepAlive();
socket_.SetNoDelay();
return true;
}
bool Connect(const io::network::Endpoint &endpoint);
/**
* This function returns `true` if the socket is in an error state.
*/
bool ErrorStatus() { return socket_.ErrorStatus(); }
bool ErrorStatus();
/**
* This function shuts down the socket.
*/
void Shutdown() { socket_.Shutdown(); }
void Shutdown();
/**
* This function closes the socket.
*/
void Close() { socket_.Close(); }
void Close();
/**
* This function is used to receive `len` bytes from the socket and stores it
* in an internal buffer. It returns `true` if the read succeeded and `false`
* if it didn't.
*/
bool Read(size_t len) {
size_t received = 0;
buffer_.write_end().Resize(buffer_.read_end().size() + len);
while (received < len) {
auto buff = buffer_.write_end().Allocate();
int got = socket_.Read(buff.data, len - received);
if (got <= 0) return false;
buffer_.write_end().Written(got);
received += got;
}
return true;
}
bool Read(size_t len);
/**
* This function returns a pointer to the read data that is currently stored
* in the client.
*/
uint8_t *GetData() { return buffer_.read_end().data(); }
uint8_t *GetData();
/**
* This function returns the size of the read data that is currently stored in
* the client.
*/
size_t GetDataSize() { return buffer_.read_end().size(); }
size_t GetDataSize();
/**
* This function removes first `len` bytes from the data buffer.
*/
void ShiftData(size_t len) { buffer_.read_end().Shift(len); }
void ShiftData(size_t len);
/**
* This function clears the data buffer.
*/
void ClearData() { buffer_.read_end().Clear(); }
void ClearData();
// Write end
bool Write(const uint8_t *data, size_t len, bool have_more = false) {
return socket_.Write(data, len, have_more);
}
bool Write(const std::string &str, bool have_more = false) {
return Write(reinterpret_cast<const uint8_t *>(str.data()), str.size(),
have_more);
}
/**
* This function writes data to the socket.
* TODO (mferencevic): the `have_more` flag currently isn't supported when
* using OpenSSL
*/
bool Write(const uint8_t *data, size_t len, bool have_more = false);
const io::network::Endpoint &endpoint() { return socket_.endpoint(); }
/**
* This function writes data to the socket.
*/
bool Write(const std::string &str, bool have_more = false);
const io::network::Endpoint &endpoint();
private:
io::network::Socket socket_;
void ReleaseSslObjects();
io::network::Socket socket_;
Buffer buffer_;
ClientContext *context_;
SSL *ssl_{nullptr};
BIO *bio_{nullptr};
};
/**
* This class provides a stream-like input side object to the client.
*/
class ClientInputStream {
class ClientInputStream final {
public:
ClientInputStream(Client &client) : client_(client) {}
ClientInputStream(Client &client);
uint8_t *data() { return client_.GetData(); }
ClientInputStream(const ClientInputStream &) = delete;
ClientInputStream(ClientInputStream &&) = delete;
ClientInputStream &operator=(const ClientInputStream &) = delete;
ClientInputStream &operator=(ClientInputStream &&) = delete;
size_t size() const { return client_.GetDataSize(); }
uint8_t *data();
void Shift(size_t len) { client_.ShiftData(len); }
size_t size() const;
void Clear() { client_.ClearData(); }
void Shift(size_t len);
void Clear();
private:
Client &client_;
@ -118,16 +133,18 @@ class ClientInputStream {
/**
* This class provides a stream-like output side object to the client.
*/
class ClientOutputStream {
class ClientOutputStream final {
public:
ClientOutputStream(Client &client) : client_(client) {}
ClientOutputStream(Client &client);
bool Write(const uint8_t *data, size_t len, bool have_more = false) {
return client_.Write(data, len, have_more);
}
bool Write(const std::string &str, bool have_more = false) {
return client_.Write(str, have_more);
}
ClientOutputStream(const ClientOutputStream &) = delete;
ClientOutputStream(ClientOutputStream &&) = delete;
ClientOutputStream &operator=(const ClientOutputStream &) = delete;
ClientOutputStream &operator=(ClientOutputStream &&) = delete;
bool Write(const uint8_t *data, size_t len, bool have_more = false);
bool Write(const std::string &str, bool have_more = false);
private:
Client &client_;

View File

@ -0,0 +1,80 @@
#include <glog/logging.h>
#include "communication/context.hpp"
namespace communication {
ClientContext::ClientContext(bool use_ssl) : use_ssl_(use_ssl), ctx_(nullptr) {
if (use_ssl_) {
ctx_ = SSL_CTX_new(TLS_client_method());
CHECK(ctx_ != nullptr) << "Couldn't create client SSL_CTX object!";
// Disable legacy SSL support. Other options can be seen here:
// https://www.openssl.org/docs/man1.0.2/ssl/SSL_CTX_set_options.html
SSL_CTX_set_options(ctx_, SSL_OP_NO_SSLv3);
}
}
ClientContext::ClientContext(const std::string &key_file,
const std::string &cert_file)
: ClientContext(true) {
if (key_file != "" && cert_file != "") {
CHECK(SSL_CTX_use_certificate_file(ctx_, cert_file.c_str(),
SSL_FILETYPE_PEM) == 1)
<< "Couldn't load client certificate from file: " << cert_file;
CHECK(SSL_CTX_use_PrivateKey_file(ctx_, key_file.c_str(),
SSL_FILETYPE_PEM) == 1)
<< "Couldn't load client private key from file: " << key_file;
}
}
SSL_CTX *ClientContext::context() { return ctx_; }
bool ClientContext::use_ssl() { return use_ssl_; }
ServerContext::ServerContext() : use_ssl_(false), ctx_(nullptr) {}
ServerContext::ServerContext(const std::string &key_file,
const std::string &cert_file,
const std::string &ca_file, bool verify_peer)
: use_ssl_(true), ctx_(SSL_CTX_new(TLS_server_method())) {
// TODO (mferencevic): add support for encrypted private keys
// TODO (mferencevic): add certificate revocation list (CRL)
CHECK(SSL_CTX_use_certificate_file(ctx_, cert_file.c_str(),
SSL_FILETYPE_PEM) == 1)
<< "Couldn't load server certificate from file: " << cert_file;
CHECK(SSL_CTX_use_PrivateKey_file(ctx_, key_file.c_str(), SSL_FILETYPE_PEM) ==
1)
<< "Couldn't load server private key from file: " << key_file;
// Disable legacy SSL support. Other options can be seen here:
// https://www.openssl.org/docs/man1.0.2/ssl/SSL_CTX_set_options.html
SSL_CTX_set_options(ctx_, SSL_OP_NO_SSLv3);
if (ca_file != "") {
// Load the certificate authority file.
CHECK(SSL_CTX_load_verify_locations(ctx_, ca_file.c_str(), nullptr) == 1)
<< "Couldn't load certificate authority from file: " << ca_file;
if (verify_peer) {
// Add the CA to list of accepted CAs that is sent to the client.
STACK_OF(X509_NAME) *ca_names = SSL_load_client_CA_file(ca_file.c_str());
CHECK(ca_names != nullptr)
<< "Couldn't load certificate authority from file: " << ca_file;
// `ca_names` doesn' need to be free'd because we pass it to
// `SSL_CTX_set_client_CA_list`:
// https://mta.openssl.org/pipermail/openssl-users/2015-May/001363.html
SSL_CTX_set_client_CA_list(ctx_, ca_names);
// Enable verification of the client certificate.
SSL_CTX_set_verify(
ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr);
}
}
}
SSL_CTX *ServerContext::context() { return ctx_; }
bool ServerContext::use_ssl() { return use_ssl_; }
} // namespace communication

View File

@ -0,0 +1,63 @@
#pragma once
#include <string>
#include <openssl/ssl.h>
namespace communication {
class ClientContext final {
public:
/**
* This constructor constructs a ClientContext that can either not use SSL
* (`use_ssl` is `false` by default), or it constructs a ClientContext that
* doesn't use a client certificate when `use_ssl` is set to `true`.
*/
explicit ClientContext(bool use_ssl = false);
/**
* This constructor constructs a ClientContext that uses SSL and uses the
* specific client private key and certificate combination. If the parameters
* `key_file` and `cert_file` are equal to "" then the constructor falls back
* to the above constructor that uses SSL without certificates.
*/
ClientContext(const std::string &key_file, const std::string &cert_file);
SSL_CTX *context();
bool use_ssl();
private:
bool use_ssl_;
SSL_CTX *ctx_;
};
class ServerContext final {
public:
/**
* This constructor constructs a ServerContext that doesn't use SSL.
*/
ServerContext();
/**
* This constructor constructs a ServerContext that uses SSL. The parameters
* `key_file` and `cert_file` can't be "" because when setting up a server it
* is mandatory to supply a private key and certificate. The parameter
* `ca_file` can be "" because SSL doesn't necessarily need to check that the
* client has a valid certificate. If you specify `verify_peer` to be `true`
* to check that the client certificate is valid, then you need to supply a
* valid `ca_file` as well.
*/
ServerContext(const std::string &key_file, const std::string &cert_file,
const std::string &ca_file = "", bool verify_peer = false);
SSL_CTX *context();
bool use_ssl();
private:
bool use_ssl_;
SSL_CTX *ctx_;
};
} // namespace communication

View File

@ -0,0 +1,13 @@
#include <openssl/err.h>
#include "communication/helpers.hpp"
namespace communication {
const std::string SslGetLastError() {
char buff[2048];
auto err = ERR_get_error();
ERR_error_string_n(err, buff, sizeof(buff));
return std::string(buff);
}
} // namespace communication

View File

@ -0,0 +1,12 @@
#pragma once
#include <string>
namespace communication {
/**
* This function reads and returns a string describing the last OpenSSL error.
*/
const std::string SslGetLastError();
} // namespace communication

View File

@ -0,0 +1,21 @@
#include <glog/logging.h>
#include <openssl/bio.h>
#include <openssl/err.h>
#include <openssl/ssl.h>
#include "utils/signals.hpp"
namespace communication {
void Init() {
// Initialize the OpenSSL library.
SSL_library_init();
OpenSSL_add_ssl_algorithms();
SSL_load_error_strings();
ERR_load_crypto_strings();
// Ignore SIGPIPE.
CHECK(utils::SignalIgnore(utils::Signal::Pipe)) << "Couldn't ignore SIGPIPE!";
}
} // namespace communication

View File

@ -0,0 +1,16 @@
#pragma once
namespace communication {
/**
* Call this function in each `main` file that uses the Communication stack. It
* is used to initialize all libraries (primarily OpenSSL) and to fix some
* issues also related to OpenSSL (handling of SIGPIPE).
*
* Description of OpenSSL init can be seen here:
* https://wiki.openssl.org/index.php/Library_Initialization
*
* NOTE: This function must be called **exactly** once.
*/
void Init();
} // namespace communication

View File

@ -13,6 +13,7 @@
#include "communication/session.hpp"
#include "io/network/epoll.hpp"
#include "io/network/socket.hpp"
#include "utils/signals.hpp"
#include "utils/thread.hpp"
#include "utils/thread/sync.hpp"
@ -28,7 +29,7 @@ namespace communication {
* expired.
*/
template <class TSession, class TSessionData>
class Listener {
class Listener final {
private:
// The maximum number of events handled per execution thread is 1. This is
// because each event represents the start of a network request and it doesn't
@ -39,14 +40,27 @@ class Listener {
using SessionHandler = Session<TSession, TSessionData>;
public:
Listener(TSessionData &data, int inactivity_timeout_sec,
const std::string &service_name)
Listener(TSessionData &data, ServerContext *context,
int inactivity_timeout_sec, const std::string &service_name,
size_t workers_count)
: data_(data),
alive_(true),
context_(context),
inactivity_timeout_sec_(inactivity_timeout_sec),
service_name_(service_name) {
std::cout << "Starting " << workers_count << " " << service_name_
<< " workers" << std::endl;
for (size_t i = 0; i < workers_count; ++i) {
worker_threads_.emplace_back([this, service_name, i]() {
utils::ThreadSetName(fmt::format("{} worker {}", service_name, i + 1));
while (alive_) {
WaitAndProcessEvents();
}
});
}
if (inactivity_timeout_sec_ > 0) {
thread_ = std::thread([this, service_name]() {
timeout_thread_ = std::thread([this, service_name]() {
utils::ThreadSetName(fmt::format("{} timeout", service_name));
while (alive_) {
{
@ -72,7 +86,10 @@ class Listener {
~Listener() {
alive_.store(false);
if (thread_.joinable()) thread_.join();
if (timeout_thread_.joinable()) timeout_thread_.join();
for (auto &worker_thread : worker_threads_) {
worker_thread.join();
}
}
Listener(const Listener &) = delete;
@ -88,20 +105,12 @@ class Listener {
void AddConnection(io::network::Socket &&connection) {
std::unique_lock<utils::SpinLock> guard(lock_);
// Set connection options.
// The socket is left to be a blocking socket, but when `Read` is called
// then a flag is manually set to enable non-blocking read that is used in
// conjunction with `EPOLLET`. That means that the socket is used in a
// non-blocking fashion for reads and a blocking fashion for writes.
connection.SetKeepAlive();
connection.SetNoDelay();
// Remember fd before moving connection into Session.
int fd = connection.fd();
// Create a new Session for the connection.
sessions_.push_back(std::make_unique<SessionHandler>(
std::move(connection), data_, inactivity_timeout_sec_));
std::move(connection), data_, context_, inactivity_timeout_sec_));
// Register the connection in Epoll.
// We want to listen to an incoming event which is edge triggered and
@ -113,6 +122,7 @@ class Listener {
sessions_.back().get());
}
private:
/**
* This function polls the event queue and processes incoming data.
* It is thread safe and is intended to be called from multiple threads and
@ -168,7 +178,6 @@ class Listener {
}
}
private:
bool ExecuteSession(SessionHandler &session) {
try {
if (session.Execute()) {
@ -223,8 +232,11 @@ class Listener {
utils::SpinLock lock_;
std::vector<std::unique_ptr<SessionHandler>> sessions_;
std::thread thread_;
std::thread timeout_thread_;
std::vector<std::thread> worker_threads_;
std::atomic<bool> alive_;
ServerContext *context_;
const int inactivity_timeout_sec_;
const std::string service_name_;
};

View File

@ -30,7 +30,7 @@ std::experimental::optional<::capnp::FlatArrayMessageReader> Client::Send(
// Connect to the remote server.
if (!client_) {
client_.emplace();
client_.emplace(&context_);
if (!client_->Connect(endpoint_)) {
LOG(ERROR) << "Couldn't connect to remote address " << endpoint_;
client_ = std::experimental::nullopt;

View File

@ -86,6 +86,8 @@ class Client {
::capnp::MessageBuilder *message);
io::network::Endpoint endpoint_;
// TODO (mferencevic): currently the RPC client is hardcoded not to use SSL
communication::ClientContext context_;
std::experimental::optional<communication::Client> client_;
std::mutex mutex_;

View File

@ -4,7 +4,7 @@ namespace communication::rpc {
Server::Server(const io::network::Endpoint &endpoint,
size_t workers_count)
: server_(endpoint, *this, -1, "RPC", workers_count) {}
: server_(endpoint, *this, &context_, -1, "RPC", workers_count) {}
void Server::StopProcessingCalls() {
server_.Shutdown();

View File

@ -78,6 +78,8 @@ class Server {
ConcurrentMap<uint64_t, RpcCallback> callbacks_;
std::mutex mutex_;
// TODO (mferencevic): currently the RPC server is hardcoded not to use SSL
communication::ServerContext context_;
communication::Server<Session, Server> server_;
}; // namespace communication::rpc

View File

@ -10,6 +10,7 @@
#include <fmt/format.h>
#include <glog/logging.h>
#include "communication/init.hpp"
#include "communication/listener.hpp"
#include "io/network/socket.hpp"
#include "utils/thread.hpp"
@ -27,6 +28,9 @@ namespace communication {
* Current Server achitecture:
* incoming connection -> server -> listener -> session
*
* NOTE: If you use this server you **must** call `communication::Init()` from
* the `main` function before using the server!
*
* @tparam TSession the server can handle different Sessions, each session
* represents a different protocol so the same network infrastructure
* can be used for handling different protocols
@ -34,7 +38,7 @@ namespace communication {
* session
*/
template <typename TSession, typename TSessionData>
class Server {
class Server final {
public:
using Socket = io::network::Socket;
@ -43,9 +47,11 @@ class Server {
* invokes workers_count workers
*/
Server(const io::network::Endpoint &endpoint, TSessionData &session_data,
int inactivity_timeout_sec, const std::string &service_name,
ServerContext *context, int inactivity_timeout_sec,
const std::string &service_name,
size_t workers_count = std::thread::hardware_concurrency())
: listener_(session_data, inactivity_timeout_sec, service_name),
: listener_(session_data, context, inactivity_timeout_sec, service_name,
workers_count),
service_name_(service_name) {
// Without server we can't continue with application so we can just
// terminate here.
@ -58,18 +64,7 @@ class Server {
}
thread_ = std::thread([this, workers_count, service_name]() {
std::cout << "Starting " << workers_count << " " << service_name
<< " workers" << std::endl;
utils::ThreadSetName(fmt::format("{} server", service_name));
for (size_t i = 0; i < workers_count; ++i) {
worker_threads_.emplace_back([this, service_name, i]() {
utils::ThreadSetName(
fmt::format("{} worker {}", service_name, i + 1));
while (alive_) {
listener_.WaitAndProcessEvents();
}
});
}
std::cout << service_name << " server is fully armed and operational"
<< std::endl;
@ -81,9 +76,6 @@ class Server {
}
std::cout << service_name << " shutting down..." << std::endl;
for (auto &worker_thread : worker_threads_) {
worker_thread.join();
}
});
}
@ -128,7 +120,6 @@ class Server {
std::atomic<bool> alive_{true};
std::thread thread_;
std::vector<std::thread> worker_threads_;
Socket socket_;
Listener<TSession, TSessionData> listener_;

View File

@ -9,7 +9,13 @@
#include <glog/logging.h>
#include <openssl/bio.h>
#include <openssl/err.h>
#include <openssl/ssl.h>
#include "communication/buffer.hpp"
#include "communication/context.hpp"
#include "communication/helpers.hpp"
#include "io/network/socket.hpp"
#include "io/network/stream_buffer.hpp"
#include "utils/exceptions.hpp"
@ -35,12 +41,19 @@ using InputStream = Buffer::ReadEnd;
* This is used to provide output from user sessions. All sessions used with the
* network stack should use this class for their output stream.
*/
class OutputStream {
class OutputStream final {
public:
OutputStream(io::network::Socket &socket) : socket_(socket) {}
OutputStream(
std::function<bool(const uint8_t *, size_t, bool)> write_function)
: write_function_(write_function) {}
OutputStream(const OutputStream &) = delete;
OutputStream(OutputStream &&) = delete;
OutputStream &operator=(const OutputStream &) = delete;
OutputStream &operator=(OutputStream &&) = delete;
bool Write(const uint8_t *data, size_t len, bool have_more = false) {
return socket_.Write(data, len, have_more);
return write_function_(data, len, have_more);
}
bool Write(const std::string &str, bool have_more = false) {
@ -49,7 +62,7 @@ class OutputStream {
}
private:
io::network::Socket &socket_;
std::function<bool(const uint8_t *, size_t, bool)> write_function_;
};
/**
@ -58,20 +71,73 @@ class OutputStream {
* wrapping.
*/
template <class TSession, class TSessionData>
class Session {
class Session final {
public:
Session(io::network::Socket &&socket, TSessionData &data,
int inactivity_timeout_sec)
ServerContext *context, int inactivity_timeout_sec)
: socket_(std::move(socket)),
output_stream_(socket_),
output_stream_([this](const uint8_t *data, size_t len, bool have_more) {
return Write(data, len, have_more);
}),
session_(data, input_buffer_.read_end(), output_stream_),
inactivity_timeout_sec_(inactivity_timeout_sec) {}
inactivity_timeout_sec_(inactivity_timeout_sec) {
// Set socket options.
// The socket is set to be a non-blocking socket. We use the socket in a
// non-blocking fashion for reads and manually simulate a blocking socket
// type for writes. This manual handling of writes is necessary because
// OpenSSL doesn't provide a way to add `recv` parameters to the `SSL_read`
// call so we can't have a blocking socket and use it in a non-blocking way
// only for reads.
// Keep alive is enabled so that the Kernel's TCP stack notifies us if a
// connection is broken and shouldn't be used anymore.
// Because we manually always set the `have_more` flag to the socket
// `Write` call we can disable the Nagle algorithm because we know that we
// are always sending optimal packets. Even if we don't send optimal
// packets, there will be no delay between packets and throughput won't
// suffer.
socket_.SetNonBlocking();
socket_.SetKeepAlive();
socket_.SetNoDelay();
// Prepare SSL if we should be using it.
if (context->use_ssl()) {
// Create a new SSL object that will be used for SSL communication.
ssl_ = SSL_new(context->context());
CHECK(ssl_ != nullptr) << "Couldn't create server SSL object!";
// Create a new BIO (block I/O) SSL object so that OpenSSL can communicate
// using our socket. We specify `BIO_NOCLOSE` to indicate to OpenSSL that
// it doesn't need to close the socket when destructing all objects (we
// handle that in our socket destructor).
bio_ = BIO_new_socket(socket_.fd(), BIO_NOCLOSE);
CHECK(bio_ != nullptr) << "Couldn't create server BIO object!";
// Connect the BIO object to the SSL object so that OpenSSL knows which
// stream it should use for communication. We use the same object for both
// the read and write end. This function cannot fail.
SSL_set_bio(ssl_, bio_, bio_);
// Indicate to OpenSSL that this connection is a server. The TLS handshake
// will be performed in the first `SSL_read` or `SSL_write` call. This
// function cannot fail.
SSL_set_accept_state(ssl_);
}
}
Session(const Session &) = delete;
Session(Session &&) = delete;
Session &operator=(const Session &) = delete;
Session &operator=(Session &&) = delete;
~Session() {
// If we are using SSL we need to free the allocated objects. Here we only
// free the SSL object because the `SSL_free` function also automatically
// frees the BIO object.
if (ssl_) {
SSL_free(ssl_);
}
}
/**
* This function is called from the communication stack when an event occurs
* indicating that there is data waiting to be read. This function calls the
@ -87,8 +153,48 @@ class Session {
// Allocate the buffer to fill the data.
auto buf = input_buffer_.write_end().Allocate();
if (ssl_) {
// We clear errors here to prevent errors piling up in the internal
// OpenSSL error queue. To see when could that be an issue read this:
// https://www.arangodb.com/2014/07/started-hate-openssl/
ERR_clear_error();
// Read data from the socket using the OpenSSL API.
auto len = SSL_read(ssl_, buf.data, buf.len);
// Check for read errors.
if (len < 0) {
auto err = SSL_get_error(ssl_, len);
if (err == SSL_ERROR_WANT_READ) {
// OpenSSL want's to read more data from the socket. We return `true`
// to stop execution of the session to wait for more data to be
// received.
return true;
} else if (err == SSL_ERROR_WANT_WRITE) {
// The OpenSSL library wants to perfrom some kind of handshake so we
// wait for the socket to become ready for a write and call the read
// again. We return `false` so that the listener calls this function
// again.
socket_.WaitForReadyWrite();
return false;
} else {
// This is a fatal error.
throw utils::BasicException(SslGetLastError());
}
} else if (len == 0) {
// The client closed the connection.
throw SessionClosedException("Session was closed by the client.");
return false;
} else {
// Notify the input buffer that it has new data.
input_buffer_.write_end().Written(len);
}
} else {
// Read from the buffer at most buf.len bytes in a non-blocking fashion.
int len = socket_.Read(buf.data, buf.len, true);
// Note, the `true` parameter for non-blocking here is redundant because
// the socket already is non-blocking.
auto len = socket_.Read(buf.data, buf.len, true);
// Check for read errors.
if (len == -1) {
@ -98,17 +204,17 @@ class Session {
if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) {
return true;
}
// Some other error occurred, throw an exception to start session cleanup.
throw utils::BasicException("Couldn't read data from socket!");
}
// Some other error occurred, throw an exception to start session
// cleanup.
throw utils::BasicException("Couldn't read data from the socket!");
} else if (len == 0) {
// The client has closed the connection.
if (len == 0) {
throw SessionClosedException("Session was closed by client.");
}
} else {
// Notify the input buffer that it has new data.
input_buffer_.write_end().Written(len);
}
}
// Execute the session.
session_.Execute();
@ -142,6 +248,52 @@ class Session {
last_event_time_ = std::chrono::steady_clock::now();
}
// TODO (mferencevic): the `have_more` flag currently isn't supported
// when using OpenSSL
bool Write(const uint8_t *data, size_t len, bool have_more = false) {
if (ssl_) {
// `SSL_write` has the interface of a normal `write` call. Because of that
// we need to ensure that all data is written to the socket manually.
while (len > 0) {
// We clear errors here to prevent errors piling up in the internal
// OpenSSL error queue. To see when could that be an issue read this:
// https://www.arangodb.com/2014/07/started-hate-openssl/
ERR_clear_error();
// Write data to the socket using OpenSSL.
auto written = SSL_write(ssl_, data, len);
if (written < 0) {
auto err = SSL_get_error(ssl_, written);
if (err == SSL_ERROR_WANT_READ) {
// OpenSSL wants to perform some kind of handshake, we need to
// ensure that there is data available for the next call to
// `SSL_write`.
socket_.WaitForReadyRead();
} else if (err == SSL_ERROR_WANT_WRITE) {
// The socket probably returned WOULDBLOCK and we need to wait for
// the output buffers to clear and reattempt the send.
socket_.WaitForReadyWrite();
} else {
// This is a fatal error.
return false;
}
} else if (written == 0) {
// The client closed the connection.
return false;
} else {
len -= written;
data += written;
}
}
return true;
} else {
// This function guarantees that all data will be written to the socket
// even if the socket is non-blocking. It will use a non-busy wait to send
// all data.
return socket_.Write(data, len, have_more);
}
}
// We own the socket.
io::network::Socket socket_;
@ -157,5 +309,9 @@ class Session {
std::chrono::steady_clock::now()};
utils::SpinLock lock_;
const int inactivity_timeout_sec_;
};
// SSL objects.
SSL *ssl_{nullptr};
BIO *bio_{nullptr};
}; // namespace communication
} // namespace communication

View File

@ -11,6 +11,7 @@
#include <netdb.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <poll.h>
#include <sys/epoll.h>
#include <sys/socket.h>
#include <sys/types.h>
@ -219,8 +220,13 @@ bool Socket::Write(const uint8_t *data, size_t len, bool have_more) {
// Terminal error, return failure.
return false;
}
// Non-fatal error, retry.
continue;
// Non-fatal error, retry after the socket is ready. This is here to
// implement a non-busy wait. If we just continue with the loop we have a
// busy wait.
if (!WaitForReadyWrite()) return false;
} else if (written == 0) {
// The client closed the connection.
return false;
} else {
len -= written;
data += written;
@ -234,7 +240,32 @@ bool Socket::Write(const std::string &s, bool have_more) {
have_more);
}
int Socket::Read(void *buffer, size_t len, bool nonblock) {
ssize_t Socket::Read(void *buffer, size_t len, bool nonblock) {
return recv(socket_, buffer, len, nonblock ? MSG_DONTWAIT : 0);
}
bool Socket::WaitForReadyRead() {
struct pollfd p;
p.fd = socket_;
p.events = POLLIN;
// We call poll with one element in the poll fds array (first and second
// arguments), also we set the timeout to -1 to block indefinitely until an
// event occurs.
int ret = poll(&p, 1, -1);
if (ret < 1) return false;
return p.revents & POLLIN;
}
bool Socket::WaitForReadyWrite() {
struct pollfd p;
p.fd = socket_;
p.events = POLLOUT;
// We call poll with one element in the poll fds array (first and second
// arguments), also we set the timeout to -1 to block indefinitely until an
// event occurs.
int ret = poll(&p, 1, -1);
if (ret < 1) return false;
return p.revents & POLLOUT;
}
} // namespace io::network

View File

@ -153,7 +153,41 @@ class Socket {
* == 0 if the client closed the connection
* < 0 if an error has occurred
*/
int Read(void *buffer, size_t len, bool nonblock = false);
ssize_t Read(void *buffer, size_t len, bool nonblock = false);
/**
* Wait until the socket becomes ready for a `Read` operation.
* This function blocks indefinitely waiting for the socket to change its
* state. This function is useful when you need a blocking operation on a
* non-blocking socket, you can call this function to ensure that your next
* `Read` operation will succeed.
*
* The function returns `true` if the wait succeded (there is data waiting to
* be read from the socket) and returns `false` if the wait failed (the socket
* was closed or something else bad happened).
*
* @return wait success status:
* true if the wait succeeded
* false if the wait failed
*/
bool WaitForReadyRead();
/**
* Wait until the socket becomes ready for a `Write` operation.
* This function blocks indefinitely waiting for the socket to change its
* state. This function is useful when you need a blocking operation on a
* non-blocking socket, you can call this function to ensure that your next
* `Write` operation will succeed.
*
* The function returns `true` if the wait succeded (the socket can be written
* to) and returns `false` if the wait failed (the socket was closed or
* something else bad happened).
*
* @return wait success status:
* true if the wait succeeded
* false if the wait failed
*/
bool WaitForReadyWrite();
private:
Socket(int fd, const Endpoint &endpoint) : socket_(fd), endpoint_(endpoint) {}

View File

@ -28,6 +28,7 @@ using communication::bolt::SessionData;
using SessionT = communication::bolt::Session<communication::InputStream,
communication::OutputStream>;
using ServerT = communication::Server<SessionT, SessionData>;
using communication::ServerContext;
// General purpose flags.
DEFINE_string(interface, "0.0.0.0",
@ -41,6 +42,8 @@ DEFINE_VALIDATED_int32(session_inactivity_timeout, 1800,
"Time in seconds after which inactive sessions will be "
"closed.",
FLAG_IN_RANGE(1, INT32_MAX));
DEFINE_string(cert_file, "", "Certificate file to use.");
DEFINE_string(key_file, "", "Key file to use.");
DEFINE_string(log_file, "", "Path to where the log should be stored.");
DEFINE_HIDDEN_string(
log_link_basename, "",
@ -142,6 +145,9 @@ int WithInit(int argc, char **argv,
stats::InitStatsLogging(get_stats_prefix());
utils::OnScopeExit stop_stats([] { stats::StopStatsLogging(); });
// Initialize the communication library.
communication::Init();
// Start memory warning logger.
utils::Scheduler mem_log_scheduler;
if (FLAGS_memory_warning_threshold > 0) {
@ -160,9 +166,17 @@ void SingleNodeMain() {
google::SetUsageMessage("Memgraph single-node database server");
database::SingleNode db;
SessionData session_data{db};
ServerContext context;
std::string service_name = "Bolt";
if (FLAGS_key_file != "" && FLAGS_cert_file != "") {
context = ServerContext(FLAGS_key_file, FLAGS_cert_file);
service_name = "BoltS";
}
ServerT server({FLAGS_interface, static_cast<uint16_t>(FLAGS_port)},
session_data, FLAGS_session_inactivity_timeout, "Bolt",
FLAGS_num_workers);
session_data, &context, FLAGS_session_inactivity_timeout,
service_name, FLAGS_num_workers);
// Setup telemetry
std::experimental::optional<telemetry::Telemetry> telemetry;
@ -214,9 +228,17 @@ void MasterMain() {
database::Master db;
SessionData session_data{db};
ServerContext context;
std::string service_name = "Bolt";
if (FLAGS_key_file != "" && FLAGS_cert_file != "") {
context = ServerContext(FLAGS_key_file, FLAGS_cert_file);
service_name = "BoltS";
}
ServerT server({FLAGS_interface, static_cast<uint16_t>(FLAGS_port)},
session_data, FLAGS_session_inactivity_timeout, "Bolt",
FLAGS_num_workers);
session_data, &context, FLAGS_session_inactivity_timeout,
service_name, FLAGS_num_workers);
// Handler for regular termination signals
auto shutdown = [&server] {

View File

@ -42,10 +42,11 @@ class TestSession {
input_stream_.Shift(size + 2);
}
communication::InputStream input_stream_;
communication::OutputStream output_stream_;
communication::InputStream &input_stream_;
communication::OutputStream &output_stream_;
};
using ContextT = communication::ServerContext;
using ServerT = communication::Server<TestSession, TestData>;
void client_run(int num, const char *interface, uint16_t port,

View File

@ -32,8 +32,8 @@ class TestSession {
output_stream_.Write(input_stream_.data(), input_stream_.size());
}
communication::InputStream input_stream_;
communication::OutputStream output_stream_;
communication::InputStream &input_stream_;
communication::OutputStream &output_stream_;
};
std::atomic<bool> run{true};
@ -63,8 +63,9 @@ TEST(Network, SocketReadHangOnConcurrentConnections) {
TestData data;
int N = (std::thread::hardware_concurrency() + 1) / 2;
int Nc = N * 3;
communication::Server<TestSession, TestData> server(endpoint, data, -1,
"Test", N);
communication::ServerContext context;
communication::Server<TestSession, TestData> server(endpoint, data, &context,
-1, "Test", N);
const auto &ep = server.endpoint();
// start clients

View File

@ -21,7 +21,8 @@ TEST(Network, Server) {
// initialize server
TestData session_data;
int N = (std::thread::hardware_concurrency() + 1) / 2;
ServerT server(endpoint, session_data, -1, "Test", N);
ContextT context;
ServerT server(endpoint, session_data, &context, -1, "Test", N);
const auto &ep = server.endpoint();
// start clients

View File

@ -22,7 +22,8 @@ TEST(Network, SessionLeak) {
// initialize server
TestData session_data;
ServerT server(endpoint, session_data, -1, "Test", 2);
ContextT context;
ServerT server(endpoint, session_data, &context, -1, "Test", 2);
// start clients
int N = 50;

View File

@ -1,2 +1,5 @@
# telemetry test binaries
add_subdirectory(telemetry)
# ssl test binaries
add_subdirectory(ssl)

View File

@ -6,3 +6,11 @@
- server.py # server script
- ../../../build_debug/tests/integration/telemetry/client # client binary
- ../../../build_debug/tests/manual/kvstore_console # kvstore console binary
- name: integration__ssl
cd: ssl
commands: ./runner.sh
infiles:
- runner.sh # runner script
- ../../../build_debug/tests/integration/ssl/tester # tester binary
enable_network: true

View File

@ -0,0 +1,6 @@
set(target_name memgraph__integration__ssl)
set(tester_target_name ${target_name}__tester)
add_executable(${tester_target_name} tester.cpp)
set_target_properties(${tester_target_name} PROPERTIES OUTPUT_NAME tester)
target_link_libraries(${tester_target_name} memgraph_lib kvstore_dummy_lib)

78
tests/integration/ssl/runner.sh Executable file
View File

@ -0,0 +1,78 @@
#!/bin/bash -e
pushd () { command pushd "$@" > /dev/null; }
popd () { command popd "$@" > /dev/null; }
die () { printf "\033[1;31m~~ Test failed! ~~\033[0m\n\n"; exit 1; }
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
cd "$DIR"
tmpdir=/tmp/memgraph_integration_ssl
if [ -d $tmpdir ]; then
rm -rf $tmpdir
fi
mkdir -p $tmpdir
cd $tmpdir
easyrsa="EasyRSA-3.0.4"
wget -nv http://deps.memgraph.io/$easyrsa.tgz
tar -xf $easyrsa.tgz
mv $easyrsa ca1
cp -r ca1 ca2
for i in ca1 ca2; do
pushd $i
./easyrsa --batch init-pki
./easyrsa --batch build-ca nopass
./easyrsa --batch build-server-full server nopass
./easyrsa --batch build-client-full client nopass
popd
done
binary_dir="$DIR/../../../build"
if [ ! -d $binary_dir ]; then
binary_dir="$DIR/../../../build_debug"
fi
set +e
echo
for i in ca1 ca2; do
for j in none ca1 ca2; do
for k in false true; do
printf "\033[1;36m~~ Server CA: $i; Client CA: $j; Verify peer: $k ~~\033[0m\n"
if [ "$j" == "none" ]; then
client_key=""
client_cert=""
else
client_key=$j/pki/private/client.key
client_cert=$j/pki/issued/client.crt
fi
$binary_dir/tests/integration/ssl/tester \
--server-key-file=$i/pki/private/server.key \
--server-cert-file=$i/pki/issued/server.crt \
--server-ca-file=$i/pki/ca.crt \
--server-verify-peer=$k \
--client-key-file=$client_key \
--client-cert-file=$client_cert
exitcode=$?
if [ "$i" == "$j" ]; then
[ $exitcode -eq 0 ] || die
else
if $k; then
[ $exitcode -ne 0 ] || die
else
[ $exitcode -eq 0 ] || die
fi
fi
printf "\033[1;32m~~ Test ok! ~~\033[0m\n\n"
done
done
done

View File

@ -0,0 +1,76 @@
#include <atomic>
#include <gflags/gflags.h>
#include <glog/logging.h>
#include "communication/client.hpp"
#include "communication/server.hpp"
#include "utils/exceptions.hpp"
DEFINE_string(server_cert_file, "", "Server certificate file to use.");
DEFINE_string(server_key_file, "", "Server key file to use.");
DEFINE_string(server_ca_file, "", "Server CA file to use.");
DEFINE_bool(server_verify_peer, false, "Set to true to verify the peer.");
DEFINE_string(client_cert_file, "", "Client certificate file to use.");
DEFINE_string(client_key_file, "", "Client key file to use.");
const std::string message = "ssl echo test";
struct EchoData {};
class EchoSession {
public:
EchoSession(EchoData &, communication::InputStream &input_stream,
communication::OutputStream &output_stream)
: input_stream_(input_stream), output_stream_(output_stream) {}
void Execute() {
if (input_stream_.size() < message.size()) return;
LOG(INFO) << "Server received message.";
if (!output_stream_.Write(input_stream_.data(), message.size())) {
throw utils::BasicException("Output stream write failed!");
}
LOG(INFO) << "Server sent message.";
input_stream_.Shift(message.size());
}
private:
communication::InputStream &input_stream_;
communication::OutputStream &output_stream_;
};
int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
// Initialize the communication stack.
communication::Init();
// Initialize the server.
EchoData echo_data;
communication::ServerContext server_context(
FLAGS_server_key_file, FLAGS_server_cert_file, FLAGS_server_ca_file,
FLAGS_server_verify_peer);
communication::Server<EchoSession, EchoData> server(
{"127.0.0.1", 0}, echo_data, &server_context, -1, "SSL", 1);
// Initialize the client.
communication::ClientContext client_context(FLAGS_client_key_file,
FLAGS_client_cert_file);
communication::Client client(&client_context);
// Connect to the server.
CHECK(client.Connect(server.endpoint())) << "Couldn't connect to server!";
// Perform echo.
CHECK(client.Write(message)) << "Client couldn't send message!";
LOG(INFO) << "Client sent message.";
CHECK(client.Read(message.size())) << "Client couldn't receive message!";
LOG(INFO) << "Client received message.";
CHECK(std::string(reinterpret_cast<const char *>(client.GetData()),
message.size()) == message)
<< "Received message isn't equal to sent message!";
return 0;
}

View File

@ -10,6 +10,7 @@
#include "io/network/endpoint.hpp"
using EndpointT = io::network::Endpoint;
using ContextT = communication::ClientContext;
using ClientT = communication::bolt::Client;
using QueryDataT = communication::bolt::QueryData;
using communication::bolt::DecodedValue;
@ -18,23 +19,23 @@ class BoltClient {
public:
BoltClient(const std::string &address, uint16_t port,
const std::string &username, const std::string &password,
const std::string & = "") {
const std::string & = "", bool use_ssl = false)
: context_(use_ssl), client_(context_) {
EndpointT endpoint(address, port);
client_ = std::make_unique<ClientT>();
if (!client_->Connect(endpoint, username, password)) {
if (!client_.Connect(endpoint, username, password)) {
LOG(FATAL) << "Could not connect to: " << endpoint;
}
}
QueryDataT Execute(const std::string &query,
const std::map<std::string, DecodedValue> &parameters) {
return client_->Execute(query, parameters);
return client_.Execute(query, parameters);
}
void Close() { client_->Close(); }
void Close() { client_.Close(); }
private:
std::unique_ptr<ClientT> client_;
ContextT context_;
ClientT client_;
};

View File

@ -331,11 +331,14 @@ int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
communication::Init();
stats::InitStatsLogging(
fmt::format("client.long_running.{}.{}", FLAGS_group, FLAGS_scenario));
Endpoint endpoint(FLAGS_address, FLAGS_port);
Client client;
ClientContext context(FLAGS_use_ssl);
Client client(&context);
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) {
LOG(FATAL) << "Couldn't connect to " << endpoint;
}

View File

@ -11,6 +11,7 @@
#include "utils/exceptions.hpp"
#include "utils/timer.hpp"
using communication::ClientContext;
using communication::bolt::Client;
using communication::bolt::DecodedValue;
using io::network::Endpoint;

View File

@ -16,6 +16,7 @@ DEFINE_int32(num_workers, 1, "Number of workers");
DEFINE_string(output, "", "Output file");
DEFINE_string(username, "", "Username for the database");
DEFINE_string(password, "", "Password for the database");
DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server.");
DEFINE_int32(duration, 30, "Number of seconds to execute benchmark");
DEFINE_string(group, "unknown", "Test group name");
@ -97,7 +98,8 @@ class TestClient {
std::thread runner_thread_;
private:
Client client_;
communication::ClientContext context_{FLAGS_use_ssl};
Client client_{&context_};
};
void RunMultithreadedTest(std::vector<std::unique_ptr<TestClient>> &clients) {

View File

@ -271,12 +271,15 @@ int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
communication::Init();
nlohmann::json config;
std::cin >> config;
auto independent_nodes_ids = [&] {
Endpoint endpoint(io::network::ResolveHostname(FLAGS_address), FLAGS_port);
Client client;
ClientContext context(FLAGS_use_ssl);
Client client(&context);
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) {
LOG(FATAL) << "Couldn't connect to " << endpoint;
}

View File

@ -19,6 +19,7 @@ DEFINE_string(address, "127.0.0.1", "Server address");
DEFINE_int32(port, 7687, "Server port");
DEFINE_string(username, "", "Username for the database");
DEFINE_string(password, "", "Password for the database");
DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server.");
using communication::bolt::DecodedValue;
@ -58,7 +59,8 @@ void ExecuteQueries(const std::vector<std::string> &queries,
for (int i = 0; i < FLAGS_num_workers; ++i) {
threads.push_back(std::thread([&]() {
Endpoint endpoint(FLAGS_address, FLAGS_port);
Client client;
ClientContext context(FLAGS_use_ssl);
Client client(&context);
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) {
LOG(FATAL) << "Couldn't connect to " << endpoint;
}
@ -100,6 +102,8 @@ int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
communication::Init();
std::ifstream ifile;
std::istream *istream{&std::cin};

View File

@ -71,6 +71,12 @@ target_link_libraries(${test_prefix}sl_position_and_count memgraph_lib kvstore_d
add_manual_test(stripped_timing.cpp)
target_link_libraries(${test_prefix}stripped_timing memgraph_lib kvstore_dummy_lib)
add_manual_test(ssl_client.cpp)
target_link_libraries(${test_prefix}ssl_client memgraph_lib kvstore_dummy_lib)
add_manual_test(ssl_server.cpp)
target_link_libraries(${test_prefix}ssl_server memgraph_lib kvstore_dummy_lib)
add_manual_test(xorshift.cpp)
target_link_libraries(${test_prefix}xorshift mg-utils)

View File

@ -10,15 +10,20 @@ DEFINE_string(address, "127.0.0.1", "Server address");
DEFINE_int32(port, 7687, "Server port");
DEFINE_string(username, "", "Username for the database");
DEFINE_string(password, "", "Password for the database");
DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server.");
int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
communication::Init();
// TODO: handle endpoint exception
io::network::Endpoint endpoint(io::network::ResolveHostname(FLAGS_address),
FLAGS_port);
communication::bolt::Client client;
communication::ClientContext context(FLAGS_use_ssl);
communication::bolt::Client client(&context);
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) return 1;

View File

@ -0,0 +1,65 @@
#include <gflags/gflags.h>
#include <glog/logging.h>
#include "communication/client.hpp"
#include "io/network/endpoint.hpp"
#include "utils/timer.hpp"
DEFINE_string(address, "127.0.0.1", "Server address");
DEFINE_int32(port, 54321, "Server port");
DEFINE_string(cert_file, "", "Certificate file to use.");
DEFINE_string(key_file, "", "Key file to use.");
bool EchoMessage(communication::Client &client, const std::string &data) {
uint16_t size = data.size();
if (!client.Write(reinterpret_cast<const uint8_t *>(&size), sizeof(size))) {
LOG(WARNING) << "Couldn't send data size!";
return false;
}
if (!client.Write(data)) {
LOG(WARNING) << "Couldn't send data!";
return false;
}
client.ClearData();
if (!client.Read(size)) {
LOG(WARNING) << "Couldn't receive data!";
return false;
}
if (std::string(reinterpret_cast<const char *>(client.GetData()), size) !=
data) {
LOG(WARNING) << "Received data isn't equal to sent data!";
return false;
}
return true;
}
int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
communication::Init();
io::network::Endpoint endpoint(FLAGS_address, FLAGS_port);
communication::ClientContext context(FLAGS_key_file, FLAGS_cert_file);
communication::Client client(&context);
if (!client.Connect(endpoint)) return 1;
bool success = true;
while (true) {
std::string s;
std::getline(std::cin, s);
if (s == "") break;
if (!EchoMessage(client, s)) {
success = false;
break;
}
}
// Send server shutdown signal. The call will fail, we don't care.
EchoMessage(client, "");
return success ? 0 : 1;
}

View File

@ -0,0 +1,73 @@
#include <atomic>
#include <gflags/gflags.h>
#include <glog/logging.h>
#include "communication/server.hpp"
#include "utils/exceptions.hpp"
DEFINE_string(address, "127.0.0.1", "Server address");
DEFINE_int32(port, 54321, "Server port");
DEFINE_string(cert_file, "", "Certificate file to use.");
DEFINE_string(key_file, "", "Key file to use.");
DEFINE_string(ca_file, "", "CA file to use.");
DEFINE_bool(verify_peer, false, "Set to true to verify the peer.");
struct EchoData {
std::atomic<bool> alive{true};
};
class EchoSession {
public:
EchoSession(EchoData &data, communication::InputStream &input_stream,
communication::OutputStream &output_stream)
: data_(data),
input_stream_(input_stream),
output_stream_(output_stream) {}
void Execute() {
if (input_stream_.size() < 2) return;
const uint8_t *data = input_stream_.data();
uint16_t size = *reinterpret_cast<const uint16_t *>(input_stream_.data());
input_stream_.Resize(size + 2);
if (input_stream_.size() < size + 2) return;
if (size == 0) {
LOG(INFO) << "Server received EOF message";
data_.alive.store(false);
return;
}
LOG(INFO) << "Server received '"
<< std::string(reinterpret_cast<const char *>(data + 2), size)
<< "'";
if (!output_stream_.Write(data + 2, size)) {
throw utils::BasicException("Output stream write failed!");
}
input_stream_.Shift(size + 2);
}
private:
EchoData &data_;
communication::InputStream &input_stream_;
communication::OutputStream &output_stream_;
};
int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
communication::Init();
// Initialize the server.
EchoData echo_data;
io::network::Endpoint endpoint(FLAGS_address, FLAGS_port);
communication::ServerContext context(FLAGS_key_file, FLAGS_cert_file,
FLAGS_ca_file, FLAGS_verify_peer);
communication::Server<EchoSession, EchoData> server(endpoint, echo_data,
&context, -1, "SSL", 1);
while (echo_data.alive) {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
return 0;
}

View File

@ -1 +1,2 @@
.long_running_stats
*.pem

View File

@ -10,6 +10,10 @@
commands: TIMEOUT=600 ./continuous_integration --properties-on-disk
infiles: *STRESS_INFILES
- name: stress_ssl
commands: TIMEOUT=600 ./continuous_integration --use-ssl
infiles: *STRESS_INFILES
- name: stress_large
project: release
commands: TIMEOUT=43200 ./continuous_integration --large-dataset

View File

@ -146,14 +146,14 @@ def connection_argument_parser():
'''
parser = ArgumentParser(description=__doc__)
parser.add_argument('--endpoint', type=str, default='localhost:7687',
parser.add_argument('--endpoint', type=str, default='127.0.0.1:7687',
help='DBMS instance endpoint. '
'Bolt protocol is the only option.')
parser.add_argument('--username', type=str, default='neo4j',
help='DBMS instance username.')
parser.add_argument('--password', type=int, default='1234',
help='DBMS instance password.')
parser.add_argument('--ssl-enabled', action='store_true',
parser.add_argument('--use-ssl', action='store_true',
help="Is SSL enabled?")
return parser
@ -163,7 +163,7 @@ def bolt_session(url, auth, ssl=False):
'''
with wrapper around Bolt session.
:param url: str, e.g. "bolt://localhost:7687"
:param url: str, e.g. "bolt://127.0.0.1:7687"
:param auth: auth method, goes directly to the Bolt driver constructor
:param ssl: bool, is ssl enabled
'''
@ -183,14 +183,15 @@ def argument_session(args):
:return: Bolt session context manager based on program arguments
'''
return bolt_session('bolt://' + args.endpoint,
(args.username, str(args.password)))
(args.username, str(args.password)),
args.use_ssl)
def argument_driver(args, ssl=False):
def argument_driver(args):
return GraphDatabase.driver(
'bolt://' + args.endpoint,
auth=(args.username, str(args.password)),
encrypted=ssl)
encrypted=args.use_ssl)
# This class is used to create and cache sessions. Session is cached by args
# used to create it and process' pid in which it was created. This makes it easy

View File

@ -73,6 +73,8 @@ BASE_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", ".."))
BUILD_DIR = os.path.join(BASE_DIR, "build")
CONFIG_DIR = os.path.join(BASE_DIR, "config")
MEASUREMENTS_FILE = os.path.join(SCRIPT_DIR, ".apollo_measurements")
KEY_FILE = os.path.join(SCRIPT_DIR, ".key.pem")
CERT_FILE = os.path.join(SCRIPT_DIR, ".cert.pem")
# long running stats file
STATS_FILE = os.path.join(SCRIPT_DIR, ".long_running_stats")
@ -132,6 +134,8 @@ parser.add_argument("--python", default = os.path.join(SCRIPT_DIR,
"ve3", "bin", "python3"), type = str)
parser.add_argument("--large-dataset", action = "store_const",
const = True, default = False)
parser.add_argument("--use-ssl", action = "store_const",
const = True, default = False)
parser.add_argument("--verbose", action = "store_const",
const = True, default = False)
args = parser.parse_args()
@ -140,6 +144,14 @@ args = parser.parse_args()
if not os.path.exists(args.memgraph):
args.memgraph = os.path.join(BASE_DIR, "build_release", "memgraph")
# generate temporary SSL certs
if args.use_ssl:
# https://unix.stackexchange.com/questions/104171/create-ssl-certificate-non-interactively
subj = "/C=HR/ST=Zagreb/L=Zagreb/O=Memgraph/CN=db.memgraph.com"
subprocess.run(["openssl", "req", "-new", "-newkey", "rsa:4096",
"-days", "365", "-nodes", "-x509", "-subj", subj,
"-keyout", KEY_FILE, "-out", CERT_FILE], check=True)
# start memgraph
cwd = os.path.dirname(args.memgraph)
cmd = [args.memgraph, "--num-workers=" + str(THREADS)]
@ -151,6 +163,8 @@ if args.durability_directory:
cmd += ["--durability-directory", args.durability_directory]
if args.properties_on_disk:
cmd += ["--properties-on-disk", "id,x"]
if args.use_ssl:
cmd += ["--cert-file", CERT_FILE, "--key-file", KEY_FILE]
proc_mg = subprocess.Popen(cmd, cwd = cwd,
env = {"MEMGRAPH_CONFIG": args.config})
time.sleep(1.0)
@ -167,6 +181,8 @@ def cleanup():
runtimes = {}
dataset = LARGE_DATASET if args.large_dataset else SMALL_DATASET
for test in dataset:
if args.use_ssl:
test["options"] += ["--use-ssl"]
runtime = run_test(args, **test)
runtimes[os.path.splitext(test["test"])[0]] = runtime
@ -176,6 +192,11 @@ ret_mg = proc_mg.wait()
if ret_mg != 0:
raise Exception("Memgraph binary returned non-zero ({})!".format(ret_mg))
# cleanup certificates
if args.use_ssl:
os.remove(KEY_FILE)
os.remove(CERT_FILE)
# measurements
measurements = ""
for key, value in runtimes.items():

View File

@ -8,6 +8,7 @@
#include "utils/timer.hpp"
using EndpointT = io::network::Endpoint;
using ClientContextT = communication::ClientContext;
using ClientT = communication::bolt::Client;
using DecodedValueT = communication::bolt::DecodedValue;
using QueryDataT = communication::bolt::QueryData;
@ -17,6 +18,7 @@ DEFINE_string(address, "127.0.0.1", "Server address");
DEFINE_int32(port, 7687, "Server port");
DEFINE_string(username, "", "Username for the database");
DEFINE_string(password, "", "Password for the database");
DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server.");
DEFINE_int32(vertex_count, 0,
"The average number of vertices in the graph per worker");
@ -51,7 +53,7 @@ class GraphSession {
}
EndpointT endpoint(FLAGS_address, FLAGS_port);
client_ = std::make_unique<ClientT>();
client_ = std::make_unique<ClientT>(&context_);
if (!client_->Connect(endpoint, FLAGS_username, FLAGS_password)) {
throw utils::BasicException("Couldn't connect to server!");
@ -60,6 +62,7 @@ class GraphSession {
private:
uint64_t id_;
ClientContextT context_{FLAGS_use_ssl};
std::unique_ptr<ClientT> client_;
std::set<uint64_t> vertices_;
@ -362,6 +365,8 @@ int main(int argc, char **argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
communication::Init();
CHECK(FLAGS_vertex_count > 0) << "Vertex count must be greater than 0!";
CHECK(FLAGS_edge_count > 0) << "Edge count must be greater than 0!";
@ -369,7 +374,8 @@ int main(int argc, char **argv) {
// create client
EndpointT endpoint(FLAGS_address, FLAGS_port);
ClientT client;
ClientContextT context(FLAGS_use_ssl);
ClientT client(&context);
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) {
throw utils::BasicException("Couldn't connect to server!");
}

View File

@ -52,8 +52,9 @@ bool QueryServer(io::network::Socket &socket) {
TEST(NetworkTimeouts, InactiveSession) {
// Instantiate the server and set the session timeout to 2 seconds.
TestData test_data;
communication::ServerContext context;
communication::Server<TestSession, TestData> server{
{"127.0.0.1", 0}, test_data, 2, "Test", 1};
{"127.0.0.1", 0}, test_data, &context, 2, "Test", 1};
// Create the client and connect to the server.
io::network::Socket client;

94
tests/unit/socket.cpp Normal file
View File

@ -0,0 +1,94 @@
#include <chrono>
#include <csignal>
#include <thread>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "io/network/socket.hpp"
#include "utils/timer.hpp"
TEST(Socket, WaitForReadyRead) {
io::network::Socket server;
ASSERT_TRUE(server.Bind({"127.0.0.1", 0}));
ASSERT_TRUE(server.Listen(1024));
std::thread thread([&server] {
io::network::Socket client;
ASSERT_TRUE(client.Connect(server.endpoint()));
std::this_thread::sleep_for(std::chrono::milliseconds(2000));
ASSERT_TRUE(client.Write("test"));
});
uint8_t buff[100];
auto client = server.Accept();
ASSERT_TRUE(client);
client->SetNonBlocking();
ASSERT_EQ(client->Read(buff, sizeof(buff)), -1);
ASSERT_TRUE(errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR);
utils::Timer timer;
ASSERT_TRUE(client->WaitForReadyRead());
ASSERT_GT(timer.Elapsed().count(), 1.0);
ASSERT_GT(client->Read(buff, sizeof(buff)), 0);
thread.join();
}
TEST(Socket, WaitForReadyWrite) {
io::network::Socket server;
ASSERT_TRUE(server.Bind({"127.0.0.1", 0}));
ASSERT_TRUE(server.Listen(1024));
std::thread thread([&server] {
uint8_t buff[10000];
io::network::Socket client;
ASSERT_TRUE(client.Connect(server.endpoint()));
client.SetNonBlocking();
// Decrease the TCP read buffer.
int len = 1024;
ASSERT_EQ(setsockopt(client.fd(), SOL_SOCKET, SO_RCVBUF, &len, sizeof(len)),
0);
std::this_thread::sleep_for(std::chrono::milliseconds(2000));
while (true) {
int ret = client.Read(buff, sizeof(buff));
if (ret == -1 && errno != EAGAIN && errno != EWOULDBLOCK &&
errno != EINTR) {
std::raise(SIGPIPE);
} else if (ret == 0) {
break;
}
}
});
auto client = server.Accept();
ASSERT_TRUE(client);
client->SetNonBlocking();
// Decrease the TCP write buffer.
int len = 1024;
ASSERT_EQ(setsockopt(client->fd(), SOL_SOCKET, SO_SNDBUF, &len, sizeof(len)),
0);
utils::Timer timer;
for (int i = 0; i < 1000000; ++i) {
ASSERT_TRUE(client->Write("test"));
}
ASSERT_GT(timer.Elapsed().count(), 1.0);
client->Close();
thread.join();
}
int main(int argc, char **argv) {
google::InitGoogleLogging(argv[0]);
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}