Implement SSL support for servers and clients
Summary: This diff implements OpenSSL support in the network stack. Currently SSL support is only enabled for Bolt connections, support for RPC connections will be added in another diff. Reviewers: buda, teon.banek Reviewed By: buda Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D1328
This commit is contained in:
parent
44821a918c
commit
1d448d40ca
@ -10,6 +10,7 @@
|
||||
* Static vertices/edges id generators exposed through the Id Cypher function.
|
||||
* Properties on disk added.
|
||||
* Telemetry added.
|
||||
* SSL support added.
|
||||
* Add `toString` function to openCypher
|
||||
|
||||
### Bug Fixes and Other Changes
|
||||
|
@ -140,6 +140,9 @@ endif()
|
||||
set(Boost_USE_STATIC_LIBS ON)
|
||||
find_package(Boost 1.62 REQUIRED COMPONENTS iostreams serialization)
|
||||
|
||||
# OpenSSL
|
||||
find_package(OpenSSL REQUIRED)
|
||||
|
||||
set(libs_dir ${CMAKE_SOURCE_DIR}/libs)
|
||||
add_subdirectory(libs EXCLUDE_FROM_ALL)
|
||||
|
||||
@ -320,6 +323,8 @@ set(CPACK_DEBIAN_PACKAGE_DESCRIPTION "${CPACK_PACKAGE_DESCRIPTION_SUMMARY}
|
||||
Contains Memgraph, the graph database. It aims to deliver developers the
|
||||
speed, simplicity and scale required to build the next generation of
|
||||
applications driver by real-time connected data.")
|
||||
# Add `openssl` package to dependencies list. Used to generate SSL certificates.
|
||||
set(CPACK_DEBIAN_memgraph_PACKAGE_DEPENDS "openssl (>= 1.1.0)")
|
||||
|
||||
# RPM specific
|
||||
set(CPACK_RPM_PACKAGE_URL https://memgraph.com)
|
||||
@ -335,6 +340,8 @@ set(CPACK_RPM_USER_BINARY_SPECFILE "${CMAKE_SOURCE_DIR}/release/rpm/memgraph.spe
|
||||
set(CPACK_RPM_PACKAGE_DESCRIPTION "Contains Memgraph, the graph database.
|
||||
It aims to deliver developers the speed, simplicity and scale required to build
|
||||
the next generation of applications driver by real-time connected data.")
|
||||
# Add `openssl` package to dependencies list. Used to generate SSL certificates.
|
||||
set(CPACK_RPM_memgraph_PACKAGE_REQUIRES "openssl >= 1.0.0")
|
||||
|
||||
# All variables must be set before including.
|
||||
include(CPack)
|
||||
|
@ -16,6 +16,12 @@
|
||||
# Port the server should listen on.
|
||||
--port=7687
|
||||
|
||||
# Path to a SSL certificate file that should be used.
|
||||
--cert-file=/etc/memgraph/ssl/cert.pem
|
||||
|
||||
# Path to a SSL key file that should be used.
|
||||
--key-file=/etc/memgraph/ssl/key.pem
|
||||
|
||||
# Number of workers used by the Memgraph server. By default, this will be the
|
||||
# number of processing units available on the machine.
|
||||
# --num-workers=8
|
||||
|
@ -13,11 +13,9 @@ from neo4j.v1 import GraphDatabase, basic_auth
|
||||
|
||||
# Initialize and configure the driver.
|
||||
# * provide the correct URL where Memgraph is reachable;
|
||||
# * use an empty user name and password, and
|
||||
# * disable encryption (not supported).
|
||||
# * use an empty user name and password.
|
||||
driver = GraphDatabase.driver("bolt://localhost:7687",
|
||||
auth=basic_auth("", ""),
|
||||
encrypted=False)
|
||||
auth=basic_auth("", ""))
|
||||
|
||||
# Start a session in which queries are executed.
|
||||
session = driver.session()
|
||||
@ -51,9 +49,7 @@ The details about Java driver can be found
|
||||
[on GitHub](https://github.com/neo4j/neo4j-java-driver).
|
||||
|
||||
The example below is equivalent to Python example. Major difference is that
|
||||
`Config` object has to be created before the driver construction. Encryption
|
||||
has to be disabled by calling `withoutEncryption` method against the `Config`
|
||||
builder.
|
||||
`Config` object has to be created before the driver construction.
|
||||
|
||||
```java
|
||||
import org.neo4j.driver.v1.*;
|
||||
@ -64,7 +60,7 @@ import java.util.*;
|
||||
public class JavaQuickStart {
|
||||
public static void main(String[] args) {
|
||||
// Initialize driver.
|
||||
Config config = Config.build().withoutEncryption().toConfig();
|
||||
Config config = Config.build().toConfig();
|
||||
Driver driver = GraphDatabase.driver("bolt://localhost:7687",
|
||||
AuthTokens.basic("",""),
|
||||
config);
|
||||
@ -93,9 +89,7 @@ public class JavaQuickStart {
|
||||
The details about Javascript driver can be found
|
||||
[on GitHub](https://github.com/neo4j/neo4j-javascript-driver).
|
||||
|
||||
The Javascript example below is equivalent to Python and Java examples. SSL
|
||||
can be disabled by passing `{encrypted: 'ENCRYPTION_OFF'}` during the driver
|
||||
construction.
|
||||
The Javascript example below is equivalent to Python and Java examples.
|
||||
|
||||
Here is an example related to `Node.js`. Memgraph doesn't have integrated
|
||||
support for `WebSocket` which is required during the execution in any web
|
||||
@ -109,8 +103,7 @@ proxy port.
|
||||
```javascript
|
||||
var neo4j = require('neo4j-driver').v1;
|
||||
var driver = neo4j.driver("bolt://localhost:7687",
|
||||
neo4j.auth.basic("neo4j", "1234"),
|
||||
{encrypted: 'ENCRYPTION_OFF'});
|
||||
neo4j.auth.basic("neo4j", "1234"));
|
||||
var session = driver.session();
|
||||
|
||||
function die() {
|
||||
@ -146,8 +139,7 @@ run_query("MATCH (n) DETACH DELETE n", function (result) {
|
||||
|
||||
The C# driver is hosted
|
||||
[on GitHub](https://github.com/neo4j/neo4j-dotnet-driver). The example below
|
||||
performs the same work as all of the previous examples. Encryption is disabled
|
||||
by setting `EncryptionLevel.NONE` on the `Config`.
|
||||
performs the same work as all of the previous examples.
|
||||
|
||||
```csh
|
||||
using System;
|
||||
@ -158,7 +150,6 @@ public class Basic {
|
||||
public static void Main(string[] args) {
|
||||
// Initialize the driver.
|
||||
var config = Config.DefaultConfig;
|
||||
config.EncryptionLevel = EncryptionLevel.None;
|
||||
using(var driver = GraphDatabase.Driver("bolt://localhost:7687", AuthTokens.None, config))
|
||||
using(var session = driver.Session())
|
||||
{
|
||||
@ -176,6 +167,18 @@ public class Basic {
|
||||
}
|
||||
```
|
||||
|
||||
### Secure Sockets Layer (SSL)
|
||||
|
||||
Secure connections are supported and enabled by default. The server initially
|
||||
ships with a self-signed testing certificate. The certificate can be replaced
|
||||
by editing the following parameters in `/etc/memgraph/memgraph.conf`:
|
||||
```
|
||||
--cert-file=/path/to/ssl/certificate.pem
|
||||
--key-file=/path/to/ssl/privatekey.pem
|
||||
```
|
||||
To disable SSL support and use insecure connections to the database you should
|
||||
set both parameters (`--cert-file` and `--key-file`) to empty values.
|
||||
|
||||
### Limitations
|
||||
|
||||
Memgraph is currently in early stage, and has a number of limitations we plan
|
||||
@ -186,9 +189,3 @@ to remove in future versions.
|
||||
Memgraph is currently single-user only. There is no way to control user
|
||||
privileges. The default user has read and write privileges over the whole
|
||||
database.
|
||||
|
||||
#### Secure Sockets Layer (SSL)
|
||||
|
||||
Secure connections are not supported. For this reason each client
|
||||
driver needs to be configured not to use encryption. Consult driver-specific
|
||||
guides for details.
|
||||
|
@ -183,7 +183,7 @@ After installing `neo4j-client`, connect to the running Memgraph instance by
|
||||
issuing the following shell command.
|
||||
|
||||
```bash
|
||||
neo4j-client --insecure -u "" -p "" localhost 7687
|
||||
neo4j-client -u "" -p "" localhost 7687
|
||||
```
|
||||
|
||||
After the client has started it should present a command prompt similar to:
|
||||
@ -191,7 +191,7 @@ After the client has started it should present a command prompt similar to:
|
||||
```bash
|
||||
neo4j-client 2.1.3
|
||||
Enter `:help` for usage hints.
|
||||
Connected to 'neo4j://@localhost:7687' (insecure)
|
||||
Connected to 'neo4j://@localhost:7687'
|
||||
neo4j>
|
||||
```
|
||||
|
||||
|
1
init
1
init
@ -8,6 +8,7 @@ required_pkgs=(git arcanist # source code control
|
||||
curl wget # for downloading libs
|
||||
uuid-dev default-jre-headless # required by antlr
|
||||
libreadline-dev # for memgraph console
|
||||
libssl-dev
|
||||
libboost-iostreams-dev
|
||||
libboost-serialization-dev
|
||||
python3 python-virtualenv python3-pip # for qa, macro_benchmark and stress tests
|
||||
|
@ -29,6 +29,16 @@ case "$1" in
|
||||
chmod 750 /var/log/memgraph || exit 1
|
||||
# Make examples directory immutable (optional)
|
||||
chattr +i -R /usr/share/memgraph/examples || true
|
||||
|
||||
# Generate SSL certificates
|
||||
if [ ! -d /etc/memgraph/ssl ]; then
|
||||
mkdir /etc/memgraph/ssl || exit 1
|
||||
openssl req -x509 -newkey rsa:4096 -days 3650 -nodes \
|
||||
-keyout /etc/memgraph/ssl/key.pem -out /etc/memgraph/ssl/cert.pem \
|
||||
-subj "/C=GB/ST=London/L=London/O=Memgraph Ltd./CN=Memgraph DB" || exit 1
|
||||
chown memgraph:memgraph /etc/memgraph/ssl/* || exit 1
|
||||
chmod 400 /etc/memgraph/ssl/* || exit 1
|
||||
fi
|
||||
;;
|
||||
|
||||
abort-upgrade|abort-remove|abort-deconfigure)
|
||||
|
@ -71,6 +71,16 @@ chown memgraph:adm /var/log/memgraph || exit 1
|
||||
chmod 750 /var/log/memgraph || exit 1
|
||||
# Make examples directory immutable (optional)
|
||||
chattr +i -R /usr/share/memgraph/examples || true
|
||||
|
||||
# Generate SSL certificates
|
||||
if [ ! -d /etc/memgraph/ssl ]; then
|
||||
mkdir /etc/memgraph/ssl || exit 1
|
||||
openssl req -x509 -newkey rsa:4096 -days 3650 -nodes \
|
||||
-keyout /etc/memgraph/ssl/key.pem -out /etc/memgraph/ssl/cert.pem \
|
||||
-subj "/C=GB/ST=London/L=London/O=Memgraph Ltd./CN=Memgraph DB" || exit 1
|
||||
chown memgraph:memgraph /etc/memgraph/ssl/* || exit 1
|
||||
chmod 400 /etc/memgraph/ssl/* || exit 1
|
||||
fi
|
||||
@RPM_SYMLINK_POSTINSTALL@
|
||||
@CPACK_RPM_SPEC_POSTINSTALL@
|
||||
|
||||
|
@ -8,6 +8,10 @@ add_subdirectory(telemetry)
|
||||
# all memgraph src files
|
||||
set(memgraph_src_files
|
||||
communication/buffer.cpp
|
||||
communication/client.cpp
|
||||
communication/context.cpp
|
||||
communication/helpers.cpp
|
||||
communication/init.cpp
|
||||
communication/bolt/v1/decoder/decoded_value.cpp
|
||||
communication/rpc/client.cpp
|
||||
communication/rpc/protocol.cpp
|
||||
@ -189,6 +193,7 @@ string(TOLOWER ${CMAKE_BUILD_TYPE} lower_build_type)
|
||||
# memgraph_lib depend on these libraries
|
||||
set(MEMGRAPH_ALL_LIBS stdc++fs Threads::Threads fmt cppitertools
|
||||
antlr_opencypher_parser_lib dl glog gflags capnp kj
|
||||
${OPENSSL_LIBRARIES}
|
||||
${Boost_IOSTREAMS_LIBRARY_RELEASE}
|
||||
${Boost_SERIALIZATION_LIBRARY_RELEASE}
|
||||
mg-utils mg-io)
|
||||
@ -206,6 +211,7 @@ endif()
|
||||
# STATIC library used by memgraph executables
|
||||
add_library(memgraph_lib STATIC ${memgraph_src_files})
|
||||
target_link_libraries(memgraph_lib ${MEMGRAPH_ALL_LIBS})
|
||||
target_include_directories(memgraph_lib PRIVATE ${OPENSSL_INCLUDE_DIR})
|
||||
add_dependencies(memgraph_lib generate_opencypher_parser)
|
||||
add_dependencies(memgraph_lib generate_lcp)
|
||||
add_dependencies(memgraph_lib generate_capnp)
|
||||
|
@ -32,9 +32,9 @@ struct QueryData {
|
||||
std::map<std::string, DecodedValue> metadata;
|
||||
};
|
||||
|
||||
class Client {
|
||||
class Client final {
|
||||
public:
|
||||
Client() {}
|
||||
explicit Client(communication::ClientContext *context) : client_(context) {}
|
||||
|
||||
Client(const Client &) = delete;
|
||||
Client(Client &&) = delete;
|
||||
|
@ -20,7 +20,7 @@ namespace communication {
|
||||
* stack where all execution when it is being done is being done on a single
|
||||
* thread.
|
||||
*/
|
||||
class Buffer {
|
||||
class Buffer final {
|
||||
private:
|
||||
// Initial capacity of the internal buffer.
|
||||
const size_t kBufferInitialSize = 65536;
|
||||
@ -28,6 +28,11 @@ class Buffer {
|
||||
public:
|
||||
Buffer();
|
||||
|
||||
Buffer(const Buffer &) = delete;
|
||||
Buffer(Buffer &&) = delete;
|
||||
Buffer &operator=(const Buffer &) = delete;
|
||||
Buffer &operator=(Buffer &&) = delete;
|
||||
|
||||
/**
|
||||
* This class provides all functions from the buffer that are needed to allow
|
||||
* reading data from the buffer.
|
||||
@ -36,6 +41,11 @@ class Buffer {
|
||||
public:
|
||||
ReadEnd(Buffer &buffer);
|
||||
|
||||
ReadEnd(const ReadEnd &) = delete;
|
||||
ReadEnd(ReadEnd &&) = delete;
|
||||
ReadEnd &operator=(const ReadEnd &) = delete;
|
||||
ReadEnd &operator=(ReadEnd &&) = delete;
|
||||
|
||||
uint8_t *data();
|
||||
|
||||
size_t size() const;
|
||||
@ -58,6 +68,11 @@ class Buffer {
|
||||
public:
|
||||
WriteEnd(Buffer &buffer);
|
||||
|
||||
WriteEnd(const WriteEnd &) = delete;
|
||||
WriteEnd(WriteEnd &&) = delete;
|
||||
WriteEnd &operator=(const WriteEnd &) = delete;
|
||||
WriteEnd &operator=(WriteEnd &&) = delete;
|
||||
|
||||
io::network::StreamBuffer Allocate();
|
||||
|
||||
void Written(size_t len);
|
||||
|
232
src/communication/client.cpp
Normal file
232
src/communication/client.cpp
Normal file
@ -0,0 +1,232 @@
|
||||
#include <glog/logging.h>
|
||||
|
||||
#include "communication/client.hpp"
|
||||
#include "communication/helpers.hpp"
|
||||
|
||||
namespace communication {
|
||||
|
||||
Client::Client(ClientContext *context) : context_(context) {}
|
||||
|
||||
Client::~Client() {
|
||||
Close();
|
||||
ReleaseSslObjects();
|
||||
}
|
||||
|
||||
bool Client::Connect(const io::network::Endpoint &endpoint) {
|
||||
// Try to establish a socket connection.
|
||||
if (!socket_.Connect(endpoint)) return false;
|
||||
|
||||
// Enable TCP keep alive for all connections.
|
||||
// Because we manually always set the `have_more` flag to the socket
|
||||
// `Write` call we can disable the Nagle algorithm because we know that we
|
||||
// are always sending optimal packets. Even if we don't send optimal
|
||||
// packets, there will be no delay between packets and throughput won't
|
||||
// suffer.
|
||||
socket_.SetKeepAlive();
|
||||
socket_.SetNoDelay();
|
||||
|
||||
if (context_->use_ssl()) {
|
||||
// Release leftover SSL objects.
|
||||
ReleaseSslObjects();
|
||||
|
||||
// Create a new SSL object that will be used for SSL communication.
|
||||
ssl_ = SSL_new(context_->context());
|
||||
if (ssl_ == nullptr) {
|
||||
LOG(WARNING) << "Couldn't create client SSL object!";
|
||||
socket_.Close();
|
||||
return false;
|
||||
}
|
||||
|
||||
// Create a new BIO (block I/O) SSL object so that OpenSSL can communicate
|
||||
// using our socket. We specify `BIO_NOCLOSE` to indicate to OpenSSL that
|
||||
// it doesn't need to close the socket when destructing all objects (we
|
||||
// handle that in our socket destructor).
|
||||
bio_ = BIO_new_socket(socket_.fd(), BIO_NOCLOSE);
|
||||
if (bio_ == nullptr) {
|
||||
LOG(WARNING) << "Couldn't create client BIO object!";
|
||||
socket_.Close();
|
||||
return false;
|
||||
}
|
||||
|
||||
// Connect the BIO object to the SSL object so that OpenSSL knows which
|
||||
// stream it should use for communication. We use the same object for both
|
||||
// the read and write end. This function cannot fail.
|
||||
SSL_set_bio(ssl_, bio_, bio_);
|
||||
|
||||
// Clear all leftover errors.
|
||||
ERR_clear_error();
|
||||
|
||||
// Perform the TLS handshake.
|
||||
auto ret = SSL_connect(ssl_);
|
||||
if (ret != 1) {
|
||||
LOG(WARNING) << "Couldn't connect to SSL server: " << SslGetLastError();
|
||||
socket_.Close();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Client::ErrorStatus() { return socket_.ErrorStatus(); }
|
||||
|
||||
void Client::Shutdown() { socket_.Shutdown(); }
|
||||
|
||||
void Client::Close() {
|
||||
if (ssl_) {
|
||||
// Perform an unidirectional SSL shutdown. That just means that we send
|
||||
// the shutdown message and don't wait or care for the result.
|
||||
SSL_shutdown(ssl_);
|
||||
}
|
||||
socket_.Close();
|
||||
}
|
||||
|
||||
bool Client::Read(size_t len) {
|
||||
size_t received = 0;
|
||||
buffer_.write_end().Resize(buffer_.read_end().size() + len);
|
||||
while (received < len) {
|
||||
auto buff = buffer_.write_end().Allocate();
|
||||
if (ssl_) {
|
||||
// We clear errors here to prevent errors piling up in the internal
|
||||
// OpenSSL error queue. To see when could that be an issue read this:
|
||||
// https://www.arangodb.com/2014/07/started-hate-openssl/
|
||||
ERR_clear_error();
|
||||
|
||||
// Read encrypted data from the socket using OpenSSL.
|
||||
auto got = SSL_read(ssl_, buff.data, len - received);
|
||||
|
||||
// Handle errors that might have occurred.
|
||||
if (got < 0) {
|
||||
auto err = SSL_get_error(ssl_, got);
|
||||
if (err == SSL_ERROR_WANT_READ) {
|
||||
// OpenSSL want's to read more data from the socket. We wait for
|
||||
// more data to be ready and retry the call.
|
||||
socket_.WaitForReadyRead();
|
||||
continue;
|
||||
} else if (err == SSL_ERROR_WANT_WRITE) {
|
||||
// The OpenSSL library probably wants to perform some kind of
|
||||
// handshake so we wait for the socket to become ready for a write
|
||||
// and call the read again.
|
||||
socket_.WaitForReadyWrite();
|
||||
continue;
|
||||
} else {
|
||||
// This is a fatal error.
|
||||
LOG(WARNING) << "Received an unexpected SSL error: " << err;
|
||||
return false;
|
||||
}
|
||||
} else if (got == 0) {
|
||||
// The server closed the connection.
|
||||
return false;
|
||||
}
|
||||
|
||||
// Notify the buffer that it has new data.
|
||||
buffer_.write_end().Written(got);
|
||||
received += got;
|
||||
} else {
|
||||
// Read raw data from the socket.
|
||||
auto got = socket_.Read(buff.data, len - received);
|
||||
|
||||
if (got <= 0) {
|
||||
// If `read` returns 0 the server has closed the connection. If `read`
|
||||
// returns -1 all of the errors that could be found in `errno` are
|
||||
// fatal errors (because we are using a blocking socket) so return a
|
||||
// read failure.
|
||||
return false;
|
||||
}
|
||||
|
||||
// Notify the buffer that it has new data.
|
||||
buffer_.write_end().Written(got);
|
||||
received += got;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
uint8_t *Client::GetData() { return buffer_.read_end().data(); }
|
||||
|
||||
size_t Client::GetDataSize() { return buffer_.read_end().size(); }
|
||||
|
||||
void Client::ShiftData(size_t len) { buffer_.read_end().Shift(len); }
|
||||
|
||||
void Client::ClearData() { buffer_.read_end().Clear(); }
|
||||
|
||||
bool Client::Write(const uint8_t *data, size_t len, bool have_more) {
|
||||
if (ssl_) {
|
||||
// `SSL_write` has the interface of a normal `write` call. Because of that
|
||||
// we need to ensure that all data is written to the socket manually.
|
||||
while (len > 0) {
|
||||
// We clear errors here to prevent errors piling up in the internal
|
||||
// OpenSSL error queue. To see when could that be an issue read this:
|
||||
// https://www.arangodb.com/2014/07/started-hate-openssl/
|
||||
ERR_clear_error();
|
||||
|
||||
// Write data to the socket using OpenSSL.
|
||||
auto written = SSL_write(ssl_, data, len);
|
||||
if (written < 0) {
|
||||
auto err = SSL_get_error(ssl_, written);
|
||||
if (err == SSL_ERROR_WANT_READ) {
|
||||
// OpenSSL wants to perform some kind of handshake, we need to
|
||||
// ensure that there is data available for the next call to
|
||||
// `SSL_write`.
|
||||
socket_.WaitForReadyRead();
|
||||
} else if (err == SSL_ERROR_WANT_WRITE) {
|
||||
// The socket probably returned WOULDBLOCK and we need to wait for
|
||||
// the output buffers to clear and reattempt the send.
|
||||
socket_.WaitForReadyWrite();
|
||||
} else {
|
||||
// This is a fatal error.
|
||||
return false;
|
||||
}
|
||||
} else if (written == 0) {
|
||||
// The client closed the connection.
|
||||
return false;
|
||||
} else {
|
||||
len -= written;
|
||||
data += written;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
} else {
|
||||
return socket_.Write(data, len, have_more);
|
||||
}
|
||||
}
|
||||
|
||||
bool Client::Write(const std::string &str, bool have_more) {
|
||||
return Write(reinterpret_cast<const uint8_t *>(str.data()), str.size(),
|
||||
have_more);
|
||||
}
|
||||
|
||||
const io::network::Endpoint &Client::endpoint() { return socket_.endpoint(); }
|
||||
|
||||
void Client::ReleaseSslObjects() {
|
||||
// If we are using SSL we need to free the allocated objects. Here we only
|
||||
// free the SSL object because the `SSL_free` function also automatically
|
||||
// frees the BIO object.
|
||||
if (ssl_) {
|
||||
SSL_free(ssl_);
|
||||
ssl_ = nullptr;
|
||||
bio_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
ClientInputStream::ClientInputStream(Client &client) : client_(client) {}
|
||||
|
||||
uint8_t *ClientInputStream::data() { return client_.GetData(); }
|
||||
|
||||
size_t ClientInputStream::size() const { return client_.GetDataSize(); }
|
||||
|
||||
void ClientInputStream::Shift(size_t len) { client_.ShiftData(len); }
|
||||
|
||||
void ClientInputStream::Clear() { client_.ClearData(); }
|
||||
|
||||
ClientOutputStream::ClientOutputStream(Client &client) : client_(client) {}
|
||||
|
||||
bool ClientOutputStream::Write(const uint8_t *data, size_t len,
|
||||
bool have_more) {
|
||||
return client_.Write(data, len, have_more);
|
||||
}
|
||||
bool ClientOutputStream::Write(const std::string &str, bool have_more) {
|
||||
return client_.Write(str, have_more);
|
||||
}
|
||||
|
||||
} // namespace communication
|
@ -1,6 +1,12 @@
|
||||
#pragma once
|
||||
|
||||
#include <openssl/bio.h>
|
||||
#include <openssl/err.h>
|
||||
#include <openssl/ssl.h>
|
||||
|
||||
#include "communication/buffer.hpp"
|
||||
#include "communication/context.hpp"
|
||||
#include "communication/init.hpp"
|
||||
#include "io/network/endpoint.hpp"
|
||||
#include "io/network/socket.hpp"
|
||||
|
||||
@ -10,106 +16,115 @@ namespace communication {
|
||||
* This class implements a generic network Client.
|
||||
* It uses blocking sockets and provides an API that can be used to receive/send
|
||||
* data over the network connection.
|
||||
*
|
||||
* NOTE: If you use this client you **must** call `communication::Init()` from
|
||||
* the `main` function before using the client!
|
||||
*/
|
||||
class Client {
|
||||
class Client final {
|
||||
public:
|
||||
explicit Client(ClientContext *context);
|
||||
|
||||
~Client();
|
||||
|
||||
Client(const Client &) = delete;
|
||||
Client(Client &&) = delete;
|
||||
Client &operator=(const Client &) = delete;
|
||||
Client &operator=(Client &&) = delete;
|
||||
|
||||
/**
|
||||
* This function connects to a remote server and returns whether the connect
|
||||
* succeeded.
|
||||
*/
|
||||
bool Connect(const io::network::Endpoint &endpoint) {
|
||||
if (!socket_.Connect(endpoint)) return false;
|
||||
socket_.SetKeepAlive();
|
||||
socket_.SetNoDelay();
|
||||
return true;
|
||||
}
|
||||
bool Connect(const io::network::Endpoint &endpoint);
|
||||
|
||||
/**
|
||||
* This function returns `true` if the socket is in an error state.
|
||||
*/
|
||||
bool ErrorStatus() { return socket_.ErrorStatus(); }
|
||||
bool ErrorStatus();
|
||||
|
||||
/**
|
||||
* This function shuts down the socket.
|
||||
*/
|
||||
void Shutdown() { socket_.Shutdown(); }
|
||||
void Shutdown();
|
||||
|
||||
/**
|
||||
* This function closes the socket.
|
||||
*/
|
||||
void Close() { socket_.Close(); }
|
||||
void Close();
|
||||
|
||||
/**
|
||||
* This function is used to receive `len` bytes from the socket and stores it
|
||||
* in an internal buffer. It returns `true` if the read succeeded and `false`
|
||||
* if it didn't.
|
||||
*/
|
||||
bool Read(size_t len) {
|
||||
size_t received = 0;
|
||||
buffer_.write_end().Resize(buffer_.read_end().size() + len);
|
||||
while (received < len) {
|
||||
auto buff = buffer_.write_end().Allocate();
|
||||
int got = socket_.Read(buff.data, len - received);
|
||||
if (got <= 0) return false;
|
||||
buffer_.write_end().Written(got);
|
||||
received += got;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
bool Read(size_t len);
|
||||
|
||||
/**
|
||||
* This function returns a pointer to the read data that is currently stored
|
||||
* in the client.
|
||||
*/
|
||||
uint8_t *GetData() { return buffer_.read_end().data(); }
|
||||
uint8_t *GetData();
|
||||
|
||||
/**
|
||||
* This function returns the size of the read data that is currently stored in
|
||||
* the client.
|
||||
*/
|
||||
size_t GetDataSize() { return buffer_.read_end().size(); }
|
||||
size_t GetDataSize();
|
||||
|
||||
/**
|
||||
* This function removes first `len` bytes from the data buffer.
|
||||
*/
|
||||
void ShiftData(size_t len) { buffer_.read_end().Shift(len); }
|
||||
void ShiftData(size_t len);
|
||||
|
||||
/**
|
||||
* This function clears the data buffer.
|
||||
*/
|
||||
void ClearData() { buffer_.read_end().Clear(); }
|
||||
void ClearData();
|
||||
|
||||
// Write end
|
||||
bool Write(const uint8_t *data, size_t len, bool have_more = false) {
|
||||
return socket_.Write(data, len, have_more);
|
||||
}
|
||||
bool Write(const std::string &str, bool have_more = false) {
|
||||
return Write(reinterpret_cast<const uint8_t *>(str.data()), str.size(),
|
||||
have_more);
|
||||
}
|
||||
/**
|
||||
* This function writes data to the socket.
|
||||
* TODO (mferencevic): the `have_more` flag currently isn't supported when
|
||||
* using OpenSSL
|
||||
*/
|
||||
bool Write(const uint8_t *data, size_t len, bool have_more = false);
|
||||
|
||||
const io::network::Endpoint &endpoint() { return socket_.endpoint(); }
|
||||
/**
|
||||
* This function writes data to the socket.
|
||||
*/
|
||||
bool Write(const std::string &str, bool have_more = false);
|
||||
|
||||
const io::network::Endpoint &endpoint();
|
||||
|
||||
private:
|
||||
io::network::Socket socket_;
|
||||
void ReleaseSslObjects();
|
||||
|
||||
io::network::Socket socket_;
|
||||
Buffer buffer_;
|
||||
|
||||
ClientContext *context_;
|
||||
SSL *ssl_{nullptr};
|
||||
BIO *bio_{nullptr};
|
||||
};
|
||||
|
||||
/**
|
||||
* This class provides a stream-like input side object to the client.
|
||||
*/
|
||||
class ClientInputStream {
|
||||
class ClientInputStream final {
|
||||
public:
|
||||
ClientInputStream(Client &client) : client_(client) {}
|
||||
ClientInputStream(Client &client);
|
||||
|
||||
uint8_t *data() { return client_.GetData(); }
|
||||
ClientInputStream(const ClientInputStream &) = delete;
|
||||
ClientInputStream(ClientInputStream &&) = delete;
|
||||
ClientInputStream &operator=(const ClientInputStream &) = delete;
|
||||
ClientInputStream &operator=(ClientInputStream &&) = delete;
|
||||
|
||||
size_t size() const { return client_.GetDataSize(); }
|
||||
uint8_t *data();
|
||||
|
||||
void Shift(size_t len) { client_.ShiftData(len); }
|
||||
size_t size() const;
|
||||
|
||||
void Clear() { client_.ClearData(); }
|
||||
void Shift(size_t len);
|
||||
|
||||
void Clear();
|
||||
|
||||
private:
|
||||
Client &client_;
|
||||
@ -118,16 +133,18 @@ class ClientInputStream {
|
||||
/**
|
||||
* This class provides a stream-like output side object to the client.
|
||||
*/
|
||||
class ClientOutputStream {
|
||||
class ClientOutputStream final {
|
||||
public:
|
||||
ClientOutputStream(Client &client) : client_(client) {}
|
||||
ClientOutputStream(Client &client);
|
||||
|
||||
bool Write(const uint8_t *data, size_t len, bool have_more = false) {
|
||||
return client_.Write(data, len, have_more);
|
||||
}
|
||||
bool Write(const std::string &str, bool have_more = false) {
|
||||
return client_.Write(str, have_more);
|
||||
}
|
||||
ClientOutputStream(const ClientOutputStream &) = delete;
|
||||
ClientOutputStream(ClientOutputStream &&) = delete;
|
||||
ClientOutputStream &operator=(const ClientOutputStream &) = delete;
|
||||
ClientOutputStream &operator=(ClientOutputStream &&) = delete;
|
||||
|
||||
bool Write(const uint8_t *data, size_t len, bool have_more = false);
|
||||
|
||||
bool Write(const std::string &str, bool have_more = false);
|
||||
|
||||
private:
|
||||
Client &client_;
|
||||
|
80
src/communication/context.cpp
Normal file
80
src/communication/context.cpp
Normal file
@ -0,0 +1,80 @@
|
||||
#include <glog/logging.h>
|
||||
|
||||
#include "communication/context.hpp"
|
||||
|
||||
namespace communication {
|
||||
|
||||
ClientContext::ClientContext(bool use_ssl) : use_ssl_(use_ssl), ctx_(nullptr) {
|
||||
if (use_ssl_) {
|
||||
ctx_ = SSL_CTX_new(TLS_client_method());
|
||||
CHECK(ctx_ != nullptr) << "Couldn't create client SSL_CTX object!";
|
||||
|
||||
// Disable legacy SSL support. Other options can be seen here:
|
||||
// https://www.openssl.org/docs/man1.0.2/ssl/SSL_CTX_set_options.html
|
||||
SSL_CTX_set_options(ctx_, SSL_OP_NO_SSLv3);
|
||||
}
|
||||
}
|
||||
|
||||
ClientContext::ClientContext(const std::string &key_file,
|
||||
const std::string &cert_file)
|
||||
: ClientContext(true) {
|
||||
if (key_file != "" && cert_file != "") {
|
||||
CHECK(SSL_CTX_use_certificate_file(ctx_, cert_file.c_str(),
|
||||
SSL_FILETYPE_PEM) == 1)
|
||||
<< "Couldn't load client certificate from file: " << cert_file;
|
||||
CHECK(SSL_CTX_use_PrivateKey_file(ctx_, key_file.c_str(),
|
||||
SSL_FILETYPE_PEM) == 1)
|
||||
<< "Couldn't load client private key from file: " << key_file;
|
||||
}
|
||||
}
|
||||
|
||||
SSL_CTX *ClientContext::context() { return ctx_; }
|
||||
|
||||
bool ClientContext::use_ssl() { return use_ssl_; }
|
||||
|
||||
ServerContext::ServerContext() : use_ssl_(false), ctx_(nullptr) {}
|
||||
|
||||
ServerContext::ServerContext(const std::string &key_file,
|
||||
const std::string &cert_file,
|
||||
const std::string &ca_file, bool verify_peer)
|
||||
: use_ssl_(true), ctx_(SSL_CTX_new(TLS_server_method())) {
|
||||
// TODO (mferencevic): add support for encrypted private keys
|
||||
// TODO (mferencevic): add certificate revocation list (CRL)
|
||||
CHECK(SSL_CTX_use_certificate_file(ctx_, cert_file.c_str(),
|
||||
SSL_FILETYPE_PEM) == 1)
|
||||
<< "Couldn't load server certificate from file: " << cert_file;
|
||||
CHECK(SSL_CTX_use_PrivateKey_file(ctx_, key_file.c_str(), SSL_FILETYPE_PEM) ==
|
||||
1)
|
||||
<< "Couldn't load server private key from file: " << key_file;
|
||||
|
||||
// Disable legacy SSL support. Other options can be seen here:
|
||||
// https://www.openssl.org/docs/man1.0.2/ssl/SSL_CTX_set_options.html
|
||||
SSL_CTX_set_options(ctx_, SSL_OP_NO_SSLv3);
|
||||
|
||||
if (ca_file != "") {
|
||||
// Load the certificate authority file.
|
||||
CHECK(SSL_CTX_load_verify_locations(ctx_, ca_file.c_str(), nullptr) == 1)
|
||||
<< "Couldn't load certificate authority from file: " << ca_file;
|
||||
|
||||
if (verify_peer) {
|
||||
// Add the CA to list of accepted CAs that is sent to the client.
|
||||
STACK_OF(X509_NAME) *ca_names = SSL_load_client_CA_file(ca_file.c_str());
|
||||
CHECK(ca_names != nullptr)
|
||||
<< "Couldn't load certificate authority from file: " << ca_file;
|
||||
// `ca_names` doesn' need to be free'd because we pass it to
|
||||
// `SSL_CTX_set_client_CA_list`:
|
||||
// https://mta.openssl.org/pipermail/openssl-users/2015-May/001363.html
|
||||
SSL_CTX_set_client_CA_list(ctx_, ca_names);
|
||||
|
||||
// Enable verification of the client certificate.
|
||||
SSL_CTX_set_verify(
|
||||
ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
SSL_CTX *ServerContext::context() { return ctx_; }
|
||||
|
||||
bool ServerContext::use_ssl() { return use_ssl_; }
|
||||
|
||||
} // namespace communication
|
63
src/communication/context.hpp
Normal file
63
src/communication/context.hpp
Normal file
@ -0,0 +1,63 @@
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#include <openssl/ssl.h>
|
||||
|
||||
namespace communication {
|
||||
|
||||
class ClientContext final {
|
||||
public:
|
||||
/**
|
||||
* This constructor constructs a ClientContext that can either not use SSL
|
||||
* (`use_ssl` is `false` by default), or it constructs a ClientContext that
|
||||
* doesn't use a client certificate when `use_ssl` is set to `true`.
|
||||
*/
|
||||
explicit ClientContext(bool use_ssl = false);
|
||||
|
||||
/**
|
||||
* This constructor constructs a ClientContext that uses SSL and uses the
|
||||
* specific client private key and certificate combination. If the parameters
|
||||
* `key_file` and `cert_file` are equal to "" then the constructor falls back
|
||||
* to the above constructor that uses SSL without certificates.
|
||||
*/
|
||||
ClientContext(const std::string &key_file, const std::string &cert_file);
|
||||
|
||||
SSL_CTX *context();
|
||||
|
||||
bool use_ssl();
|
||||
|
||||
private:
|
||||
bool use_ssl_;
|
||||
SSL_CTX *ctx_;
|
||||
};
|
||||
|
||||
class ServerContext final {
|
||||
public:
|
||||
/**
|
||||
* This constructor constructs a ServerContext that doesn't use SSL.
|
||||
*/
|
||||
ServerContext();
|
||||
|
||||
/**
|
||||
* This constructor constructs a ServerContext that uses SSL. The parameters
|
||||
* `key_file` and `cert_file` can't be "" because when setting up a server it
|
||||
* is mandatory to supply a private key and certificate. The parameter
|
||||
* `ca_file` can be "" because SSL doesn't necessarily need to check that the
|
||||
* client has a valid certificate. If you specify `verify_peer` to be `true`
|
||||
* to check that the client certificate is valid, then you need to supply a
|
||||
* valid `ca_file` as well.
|
||||
*/
|
||||
ServerContext(const std::string &key_file, const std::string &cert_file,
|
||||
const std::string &ca_file = "", bool verify_peer = false);
|
||||
|
||||
SSL_CTX *context();
|
||||
|
||||
bool use_ssl();
|
||||
|
||||
private:
|
||||
bool use_ssl_;
|
||||
SSL_CTX *ctx_;
|
||||
};
|
||||
|
||||
} // namespace communication
|
13
src/communication/helpers.cpp
Normal file
13
src/communication/helpers.cpp
Normal file
@ -0,0 +1,13 @@
|
||||
#include <openssl/err.h>
|
||||
|
||||
#include "communication/helpers.hpp"
|
||||
|
||||
namespace communication {
|
||||
|
||||
const std::string SslGetLastError() {
|
||||
char buff[2048];
|
||||
auto err = ERR_get_error();
|
||||
ERR_error_string_n(err, buff, sizeof(buff));
|
||||
return std::string(buff);
|
||||
}
|
||||
} // namespace communication
|
12
src/communication/helpers.hpp
Normal file
12
src/communication/helpers.hpp
Normal file
@ -0,0 +1,12 @@
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace communication {
|
||||
|
||||
/**
|
||||
* This function reads and returns a string describing the last OpenSSL error.
|
||||
*/
|
||||
const std::string SslGetLastError();
|
||||
|
||||
} // namespace communication
|
21
src/communication/init.cpp
Normal file
21
src/communication/init.cpp
Normal file
@ -0,0 +1,21 @@
|
||||
#include <glog/logging.h>
|
||||
|
||||
#include <openssl/bio.h>
|
||||
#include <openssl/err.h>
|
||||
#include <openssl/ssl.h>
|
||||
|
||||
#include "utils/signals.hpp"
|
||||
|
||||
namespace communication {
|
||||
|
||||
void Init() {
|
||||
// Initialize the OpenSSL library.
|
||||
SSL_library_init();
|
||||
OpenSSL_add_ssl_algorithms();
|
||||
SSL_load_error_strings();
|
||||
ERR_load_crypto_strings();
|
||||
|
||||
// Ignore SIGPIPE.
|
||||
CHECK(utils::SignalIgnore(utils::Signal::Pipe)) << "Couldn't ignore SIGPIPE!";
|
||||
}
|
||||
} // namespace communication
|
16
src/communication/init.hpp
Normal file
16
src/communication/init.hpp
Normal file
@ -0,0 +1,16 @@
|
||||
#pragma once
|
||||
|
||||
namespace communication {
|
||||
|
||||
/**
|
||||
* Call this function in each `main` file that uses the Communication stack. It
|
||||
* is used to initialize all libraries (primarily OpenSSL) and to fix some
|
||||
* issues also related to OpenSSL (handling of SIGPIPE).
|
||||
*
|
||||
* Description of OpenSSL init can be seen here:
|
||||
* https://wiki.openssl.org/index.php/Library_Initialization
|
||||
*
|
||||
* NOTE: This function must be called **exactly** once.
|
||||
*/
|
||||
void Init();
|
||||
} // namespace communication
|
@ -13,6 +13,7 @@
|
||||
#include "communication/session.hpp"
|
||||
#include "io/network/epoll.hpp"
|
||||
#include "io/network/socket.hpp"
|
||||
#include "utils/signals.hpp"
|
||||
#include "utils/thread.hpp"
|
||||
#include "utils/thread/sync.hpp"
|
||||
|
||||
@ -28,7 +29,7 @@ namespace communication {
|
||||
* expired.
|
||||
*/
|
||||
template <class TSession, class TSessionData>
|
||||
class Listener {
|
||||
class Listener final {
|
||||
private:
|
||||
// The maximum number of events handled per execution thread is 1. This is
|
||||
// because each event represents the start of a network request and it doesn't
|
||||
@ -39,14 +40,27 @@ class Listener {
|
||||
using SessionHandler = Session<TSession, TSessionData>;
|
||||
|
||||
public:
|
||||
Listener(TSessionData &data, int inactivity_timeout_sec,
|
||||
const std::string &service_name)
|
||||
Listener(TSessionData &data, ServerContext *context,
|
||||
int inactivity_timeout_sec, const std::string &service_name,
|
||||
size_t workers_count)
|
||||
: data_(data),
|
||||
alive_(true),
|
||||
context_(context),
|
||||
inactivity_timeout_sec_(inactivity_timeout_sec),
|
||||
service_name_(service_name) {
|
||||
std::cout << "Starting " << workers_count << " " << service_name_
|
||||
<< " workers" << std::endl;
|
||||
for (size_t i = 0; i < workers_count; ++i) {
|
||||
worker_threads_.emplace_back([this, service_name, i]() {
|
||||
utils::ThreadSetName(fmt::format("{} worker {}", service_name, i + 1));
|
||||
while (alive_) {
|
||||
WaitAndProcessEvents();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
if (inactivity_timeout_sec_ > 0) {
|
||||
thread_ = std::thread([this, service_name]() {
|
||||
timeout_thread_ = std::thread([this, service_name]() {
|
||||
utils::ThreadSetName(fmt::format("{} timeout", service_name));
|
||||
while (alive_) {
|
||||
{
|
||||
@ -72,7 +86,10 @@ class Listener {
|
||||
|
||||
~Listener() {
|
||||
alive_.store(false);
|
||||
if (thread_.joinable()) thread_.join();
|
||||
if (timeout_thread_.joinable()) timeout_thread_.join();
|
||||
for (auto &worker_thread : worker_threads_) {
|
||||
worker_thread.join();
|
||||
}
|
||||
}
|
||||
|
||||
Listener(const Listener &) = delete;
|
||||
@ -88,20 +105,12 @@ class Listener {
|
||||
void AddConnection(io::network::Socket &&connection) {
|
||||
std::unique_lock<utils::SpinLock> guard(lock_);
|
||||
|
||||
// Set connection options.
|
||||
// The socket is left to be a blocking socket, but when `Read` is called
|
||||
// then a flag is manually set to enable non-blocking read that is used in
|
||||
// conjunction with `EPOLLET`. That means that the socket is used in a
|
||||
// non-blocking fashion for reads and a blocking fashion for writes.
|
||||
connection.SetKeepAlive();
|
||||
connection.SetNoDelay();
|
||||
|
||||
// Remember fd before moving connection into Session.
|
||||
int fd = connection.fd();
|
||||
|
||||
// Create a new Session for the connection.
|
||||
sessions_.push_back(std::make_unique<SessionHandler>(
|
||||
std::move(connection), data_, inactivity_timeout_sec_));
|
||||
std::move(connection), data_, context_, inactivity_timeout_sec_));
|
||||
|
||||
// Register the connection in Epoll.
|
||||
// We want to listen to an incoming event which is edge triggered and
|
||||
@ -113,6 +122,7 @@ class Listener {
|
||||
sessions_.back().get());
|
||||
}
|
||||
|
||||
private:
|
||||
/**
|
||||
* This function polls the event queue and processes incoming data.
|
||||
* It is thread safe and is intended to be called from multiple threads and
|
||||
@ -168,7 +178,6 @@ class Listener {
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
bool ExecuteSession(SessionHandler &session) {
|
||||
try {
|
||||
if (session.Execute()) {
|
||||
@ -223,8 +232,11 @@ class Listener {
|
||||
utils::SpinLock lock_;
|
||||
std::vector<std::unique_ptr<SessionHandler>> sessions_;
|
||||
|
||||
std::thread thread_;
|
||||
std::thread timeout_thread_;
|
||||
std::vector<std::thread> worker_threads_;
|
||||
std::atomic<bool> alive_;
|
||||
|
||||
ServerContext *context_;
|
||||
const int inactivity_timeout_sec_;
|
||||
const std::string service_name_;
|
||||
};
|
||||
|
@ -30,7 +30,7 @@ std::experimental::optional<::capnp::FlatArrayMessageReader> Client::Send(
|
||||
|
||||
// Connect to the remote server.
|
||||
if (!client_) {
|
||||
client_.emplace();
|
||||
client_.emplace(&context_);
|
||||
if (!client_->Connect(endpoint_)) {
|
||||
LOG(ERROR) << "Couldn't connect to remote address " << endpoint_;
|
||||
client_ = std::experimental::nullopt;
|
||||
|
@ -86,6 +86,8 @@ class Client {
|
||||
::capnp::MessageBuilder *message);
|
||||
|
||||
io::network::Endpoint endpoint_;
|
||||
// TODO (mferencevic): currently the RPC client is hardcoded not to use SSL
|
||||
communication::ClientContext context_;
|
||||
std::experimental::optional<communication::Client> client_;
|
||||
|
||||
std::mutex mutex_;
|
||||
|
@ -4,7 +4,7 @@ namespace communication::rpc {
|
||||
|
||||
Server::Server(const io::network::Endpoint &endpoint,
|
||||
size_t workers_count)
|
||||
: server_(endpoint, *this, -1, "RPC", workers_count) {}
|
||||
: server_(endpoint, *this, &context_, -1, "RPC", workers_count) {}
|
||||
|
||||
void Server::StopProcessingCalls() {
|
||||
server_.Shutdown();
|
||||
|
@ -78,6 +78,8 @@ class Server {
|
||||
ConcurrentMap<uint64_t, RpcCallback> callbacks_;
|
||||
|
||||
std::mutex mutex_;
|
||||
// TODO (mferencevic): currently the RPC server is hardcoded not to use SSL
|
||||
communication::ServerContext context_;
|
||||
communication::Server<Session, Server> server_;
|
||||
}; // namespace communication::rpc
|
||||
|
||||
|
@ -10,6 +10,7 @@
|
||||
#include <fmt/format.h>
|
||||
#include <glog/logging.h>
|
||||
|
||||
#include "communication/init.hpp"
|
||||
#include "communication/listener.hpp"
|
||||
#include "io/network/socket.hpp"
|
||||
#include "utils/thread.hpp"
|
||||
@ -27,6 +28,9 @@ namespace communication {
|
||||
* Current Server achitecture:
|
||||
* incoming connection -> server -> listener -> session
|
||||
*
|
||||
* NOTE: If you use this server you **must** call `communication::Init()` from
|
||||
* the `main` function before using the server!
|
||||
*
|
||||
* @tparam TSession the server can handle different Sessions, each session
|
||||
* represents a different protocol so the same network infrastructure
|
||||
* can be used for handling different protocols
|
||||
@ -34,7 +38,7 @@ namespace communication {
|
||||
* session
|
||||
*/
|
||||
template <typename TSession, typename TSessionData>
|
||||
class Server {
|
||||
class Server final {
|
||||
public:
|
||||
using Socket = io::network::Socket;
|
||||
|
||||
@ -43,9 +47,11 @@ class Server {
|
||||
* invokes workers_count workers
|
||||
*/
|
||||
Server(const io::network::Endpoint &endpoint, TSessionData &session_data,
|
||||
int inactivity_timeout_sec, const std::string &service_name,
|
||||
ServerContext *context, int inactivity_timeout_sec,
|
||||
const std::string &service_name,
|
||||
size_t workers_count = std::thread::hardware_concurrency())
|
||||
: listener_(session_data, inactivity_timeout_sec, service_name),
|
||||
: listener_(session_data, context, inactivity_timeout_sec, service_name,
|
||||
workers_count),
|
||||
service_name_(service_name) {
|
||||
// Without server we can't continue with application so we can just
|
||||
// terminate here.
|
||||
@ -58,18 +64,7 @@ class Server {
|
||||
}
|
||||
|
||||
thread_ = std::thread([this, workers_count, service_name]() {
|
||||
std::cout << "Starting " << workers_count << " " << service_name
|
||||
<< " workers" << std::endl;
|
||||
utils::ThreadSetName(fmt::format("{} server", service_name));
|
||||
for (size_t i = 0; i < workers_count; ++i) {
|
||||
worker_threads_.emplace_back([this, service_name, i]() {
|
||||
utils::ThreadSetName(
|
||||
fmt::format("{} worker {}", service_name, i + 1));
|
||||
while (alive_) {
|
||||
listener_.WaitAndProcessEvents();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
std::cout << service_name << " server is fully armed and operational"
|
||||
<< std::endl;
|
||||
@ -81,9 +76,6 @@ class Server {
|
||||
}
|
||||
|
||||
std::cout << service_name << " shutting down..." << std::endl;
|
||||
for (auto &worker_thread : worker_threads_) {
|
||||
worker_thread.join();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@ -128,7 +120,6 @@ class Server {
|
||||
|
||||
std::atomic<bool> alive_{true};
|
||||
std::thread thread_;
|
||||
std::vector<std::thread> worker_threads_;
|
||||
|
||||
Socket socket_;
|
||||
Listener<TSession, TSessionData> listener_;
|
||||
|
@ -9,7 +9,13 @@
|
||||
|
||||
#include <glog/logging.h>
|
||||
|
||||
#include <openssl/bio.h>
|
||||
#include <openssl/err.h>
|
||||
#include <openssl/ssl.h>
|
||||
|
||||
#include "communication/buffer.hpp"
|
||||
#include "communication/context.hpp"
|
||||
#include "communication/helpers.hpp"
|
||||
#include "io/network/socket.hpp"
|
||||
#include "io/network/stream_buffer.hpp"
|
||||
#include "utils/exceptions.hpp"
|
||||
@ -35,12 +41,19 @@ using InputStream = Buffer::ReadEnd;
|
||||
* This is used to provide output from user sessions. All sessions used with the
|
||||
* network stack should use this class for their output stream.
|
||||
*/
|
||||
class OutputStream {
|
||||
class OutputStream final {
|
||||
public:
|
||||
OutputStream(io::network::Socket &socket) : socket_(socket) {}
|
||||
OutputStream(
|
||||
std::function<bool(const uint8_t *, size_t, bool)> write_function)
|
||||
: write_function_(write_function) {}
|
||||
|
||||
OutputStream(const OutputStream &) = delete;
|
||||
OutputStream(OutputStream &&) = delete;
|
||||
OutputStream &operator=(const OutputStream &) = delete;
|
||||
OutputStream &operator=(OutputStream &&) = delete;
|
||||
|
||||
bool Write(const uint8_t *data, size_t len, bool have_more = false) {
|
||||
return socket_.Write(data, len, have_more);
|
||||
return write_function_(data, len, have_more);
|
||||
}
|
||||
|
||||
bool Write(const std::string &str, bool have_more = false) {
|
||||
@ -49,7 +62,7 @@ class OutputStream {
|
||||
}
|
||||
|
||||
private:
|
||||
io::network::Socket &socket_;
|
||||
std::function<bool(const uint8_t *, size_t, bool)> write_function_;
|
||||
};
|
||||
|
||||
/**
|
||||
@ -58,20 +71,73 @@ class OutputStream {
|
||||
* wrapping.
|
||||
*/
|
||||
template <class TSession, class TSessionData>
|
||||
class Session {
|
||||
class Session final {
|
||||
public:
|
||||
Session(io::network::Socket &&socket, TSessionData &data,
|
||||
int inactivity_timeout_sec)
|
||||
ServerContext *context, int inactivity_timeout_sec)
|
||||
: socket_(std::move(socket)),
|
||||
output_stream_(socket_),
|
||||
output_stream_([this](const uint8_t *data, size_t len, bool have_more) {
|
||||
return Write(data, len, have_more);
|
||||
}),
|
||||
session_(data, input_buffer_.read_end(), output_stream_),
|
||||
inactivity_timeout_sec_(inactivity_timeout_sec) {}
|
||||
inactivity_timeout_sec_(inactivity_timeout_sec) {
|
||||
// Set socket options.
|
||||
// The socket is set to be a non-blocking socket. We use the socket in a
|
||||
// non-blocking fashion for reads and manually simulate a blocking socket
|
||||
// type for writes. This manual handling of writes is necessary because
|
||||
// OpenSSL doesn't provide a way to add `recv` parameters to the `SSL_read`
|
||||
// call so we can't have a blocking socket and use it in a non-blocking way
|
||||
// only for reads.
|
||||
// Keep alive is enabled so that the Kernel's TCP stack notifies us if a
|
||||
// connection is broken and shouldn't be used anymore.
|
||||
// Because we manually always set the `have_more` flag to the socket
|
||||
// `Write` call we can disable the Nagle algorithm because we know that we
|
||||
// are always sending optimal packets. Even if we don't send optimal
|
||||
// packets, there will be no delay between packets and throughput won't
|
||||
// suffer.
|
||||
socket_.SetNonBlocking();
|
||||
socket_.SetKeepAlive();
|
||||
socket_.SetNoDelay();
|
||||
|
||||
// Prepare SSL if we should be using it.
|
||||
if (context->use_ssl()) {
|
||||
// Create a new SSL object that will be used for SSL communication.
|
||||
ssl_ = SSL_new(context->context());
|
||||
CHECK(ssl_ != nullptr) << "Couldn't create server SSL object!";
|
||||
|
||||
// Create a new BIO (block I/O) SSL object so that OpenSSL can communicate
|
||||
// using our socket. We specify `BIO_NOCLOSE` to indicate to OpenSSL that
|
||||
// it doesn't need to close the socket when destructing all objects (we
|
||||
// handle that in our socket destructor).
|
||||
bio_ = BIO_new_socket(socket_.fd(), BIO_NOCLOSE);
|
||||
CHECK(bio_ != nullptr) << "Couldn't create server BIO object!";
|
||||
|
||||
// Connect the BIO object to the SSL object so that OpenSSL knows which
|
||||
// stream it should use for communication. We use the same object for both
|
||||
// the read and write end. This function cannot fail.
|
||||
SSL_set_bio(ssl_, bio_, bio_);
|
||||
|
||||
// Indicate to OpenSSL that this connection is a server. The TLS handshake
|
||||
// will be performed in the first `SSL_read` or `SSL_write` call. This
|
||||
// function cannot fail.
|
||||
SSL_set_accept_state(ssl_);
|
||||
}
|
||||
}
|
||||
|
||||
Session(const Session &) = delete;
|
||||
Session(Session &&) = delete;
|
||||
Session &operator=(const Session &) = delete;
|
||||
Session &operator=(Session &&) = delete;
|
||||
|
||||
~Session() {
|
||||
// If we are using SSL we need to free the allocated objects. Here we only
|
||||
// free the SSL object because the `SSL_free` function also automatically
|
||||
// frees the BIO object.
|
||||
if (ssl_) {
|
||||
SSL_free(ssl_);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* This function is called from the communication stack when an event occurs
|
||||
* indicating that there is data waiting to be read. This function calls the
|
||||
@ -87,29 +153,69 @@ class Session {
|
||||
|
||||
// Allocate the buffer to fill the data.
|
||||
auto buf = input_buffer_.write_end().Allocate();
|
||||
// Read from the buffer at most buf.len bytes in a non-blocking fashion.
|
||||
int len = socket_.Read(buf.data, buf.len, true);
|
||||
|
||||
// Check for read errors.
|
||||
if (len == -1) {
|
||||
// This means read would block or read was interrupted by signal, we
|
||||
// return `true` to indicate that all data is processad and to stop
|
||||
// reading of data.
|
||||
if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) {
|
||||
return true;
|
||||
if (ssl_) {
|
||||
// We clear errors here to prevent errors piling up in the internal
|
||||
// OpenSSL error queue. To see when could that be an issue read this:
|
||||
// https://www.arangodb.com/2014/07/started-hate-openssl/
|
||||
ERR_clear_error();
|
||||
|
||||
// Read data from the socket using the OpenSSL API.
|
||||
auto len = SSL_read(ssl_, buf.data, buf.len);
|
||||
|
||||
// Check for read errors.
|
||||
if (len < 0) {
|
||||
auto err = SSL_get_error(ssl_, len);
|
||||
if (err == SSL_ERROR_WANT_READ) {
|
||||
// OpenSSL want's to read more data from the socket. We return `true`
|
||||
// to stop execution of the session to wait for more data to be
|
||||
// received.
|
||||
return true;
|
||||
} else if (err == SSL_ERROR_WANT_WRITE) {
|
||||
// The OpenSSL library wants to perfrom some kind of handshake so we
|
||||
// wait for the socket to become ready for a write and call the read
|
||||
// again. We return `false` so that the listener calls this function
|
||||
// again.
|
||||
socket_.WaitForReadyWrite();
|
||||
return false;
|
||||
} else {
|
||||
// This is a fatal error.
|
||||
throw utils::BasicException(SslGetLastError());
|
||||
}
|
||||
} else if (len == 0) {
|
||||
// The client closed the connection.
|
||||
throw SessionClosedException("Session was closed by the client.");
|
||||
return false;
|
||||
} else {
|
||||
// Notify the input buffer that it has new data.
|
||||
input_buffer_.write_end().Written(len);
|
||||
}
|
||||
// Some other error occurred, throw an exception to start session cleanup.
|
||||
throw utils::BasicException("Couldn't read data from socket!");
|
||||
}
|
||||
} else {
|
||||
// Read from the buffer at most buf.len bytes in a non-blocking fashion.
|
||||
// Note, the `true` parameter for non-blocking here is redundant because
|
||||
// the socket already is non-blocking.
|
||||
auto len = socket_.Read(buf.data, buf.len, true);
|
||||
|
||||
// The client has closed the connection.
|
||||
if (len == 0) {
|
||||
throw SessionClosedException("Session was closed by client.");
|
||||
// Check for read errors.
|
||||
if (len == -1) {
|
||||
// This means read would block or read was interrupted by signal, we
|
||||
// return `true` to indicate that all data is processad and to stop
|
||||
// reading of data.
|
||||
if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) {
|
||||
return true;
|
||||
}
|
||||
// Some other error occurred, throw an exception to start session
|
||||
// cleanup.
|
||||
throw utils::BasicException("Couldn't read data from the socket!");
|
||||
} else if (len == 0) {
|
||||
// The client has closed the connection.
|
||||
throw SessionClosedException("Session was closed by client.");
|
||||
} else {
|
||||
// Notify the input buffer that it has new data.
|
||||
input_buffer_.write_end().Written(len);
|
||||
}
|
||||
}
|
||||
|
||||
// Notify the input buffer that it has new data.
|
||||
input_buffer_.write_end().Written(len);
|
||||
|
||||
// Execute the session.
|
||||
session_.Execute();
|
||||
|
||||
@ -142,6 +248,52 @@ class Session {
|
||||
last_event_time_ = std::chrono::steady_clock::now();
|
||||
}
|
||||
|
||||
// TODO (mferencevic): the `have_more` flag currently isn't supported
|
||||
// when using OpenSSL
|
||||
bool Write(const uint8_t *data, size_t len, bool have_more = false) {
|
||||
if (ssl_) {
|
||||
// `SSL_write` has the interface of a normal `write` call. Because of that
|
||||
// we need to ensure that all data is written to the socket manually.
|
||||
while (len > 0) {
|
||||
// We clear errors here to prevent errors piling up in the internal
|
||||
// OpenSSL error queue. To see when could that be an issue read this:
|
||||
// https://www.arangodb.com/2014/07/started-hate-openssl/
|
||||
ERR_clear_error();
|
||||
|
||||
// Write data to the socket using OpenSSL.
|
||||
auto written = SSL_write(ssl_, data, len);
|
||||
if (written < 0) {
|
||||
auto err = SSL_get_error(ssl_, written);
|
||||
if (err == SSL_ERROR_WANT_READ) {
|
||||
// OpenSSL wants to perform some kind of handshake, we need to
|
||||
// ensure that there is data available for the next call to
|
||||
// `SSL_write`.
|
||||
socket_.WaitForReadyRead();
|
||||
} else if (err == SSL_ERROR_WANT_WRITE) {
|
||||
// The socket probably returned WOULDBLOCK and we need to wait for
|
||||
// the output buffers to clear and reattempt the send.
|
||||
socket_.WaitForReadyWrite();
|
||||
} else {
|
||||
// This is a fatal error.
|
||||
return false;
|
||||
}
|
||||
} else if (written == 0) {
|
||||
// The client closed the connection.
|
||||
return false;
|
||||
} else {
|
||||
len -= written;
|
||||
data += written;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
} else {
|
||||
// This function guarantees that all data will be written to the socket
|
||||
// even if the socket is non-blocking. It will use a non-busy wait to send
|
||||
// all data.
|
||||
return socket_.Write(data, len, have_more);
|
||||
}
|
||||
}
|
||||
|
||||
// We own the socket.
|
||||
io::network::Socket socket_;
|
||||
|
||||
@ -157,5 +309,9 @@ class Session {
|
||||
std::chrono::steady_clock::now()};
|
||||
utils::SpinLock lock_;
|
||||
const int inactivity_timeout_sec_;
|
||||
};
|
||||
|
||||
// SSL objects.
|
||||
SSL *ssl_{nullptr};
|
||||
BIO *bio_{nullptr};
|
||||
}; // namespace communication
|
||||
} // namespace communication
|
||||
|
@ -11,6 +11,7 @@
|
||||
#include <netdb.h>
|
||||
#include <netinet/in.h>
|
||||
#include <netinet/tcp.h>
|
||||
#include <poll.h>
|
||||
#include <sys/epoll.h>
|
||||
#include <sys/socket.h>
|
||||
#include <sys/types.h>
|
||||
@ -219,8 +220,13 @@ bool Socket::Write(const uint8_t *data, size_t len, bool have_more) {
|
||||
// Terminal error, return failure.
|
||||
return false;
|
||||
}
|
||||
// Non-fatal error, retry.
|
||||
continue;
|
||||
// Non-fatal error, retry after the socket is ready. This is here to
|
||||
// implement a non-busy wait. If we just continue with the loop we have a
|
||||
// busy wait.
|
||||
if (!WaitForReadyWrite()) return false;
|
||||
} else if (written == 0) {
|
||||
// The client closed the connection.
|
||||
return false;
|
||||
} else {
|
||||
len -= written;
|
||||
data += written;
|
||||
@ -234,7 +240,32 @@ bool Socket::Write(const std::string &s, bool have_more) {
|
||||
have_more);
|
||||
}
|
||||
|
||||
int Socket::Read(void *buffer, size_t len, bool nonblock) {
|
||||
ssize_t Socket::Read(void *buffer, size_t len, bool nonblock) {
|
||||
return recv(socket_, buffer, len, nonblock ? MSG_DONTWAIT : 0);
|
||||
}
|
||||
|
||||
bool Socket::WaitForReadyRead() {
|
||||
struct pollfd p;
|
||||
p.fd = socket_;
|
||||
p.events = POLLIN;
|
||||
// We call poll with one element in the poll fds array (first and second
|
||||
// arguments), also we set the timeout to -1 to block indefinitely until an
|
||||
// event occurs.
|
||||
int ret = poll(&p, 1, -1);
|
||||
if (ret < 1) return false;
|
||||
return p.revents & POLLIN;
|
||||
}
|
||||
|
||||
bool Socket::WaitForReadyWrite() {
|
||||
struct pollfd p;
|
||||
p.fd = socket_;
|
||||
p.events = POLLOUT;
|
||||
// We call poll with one element in the poll fds array (first and second
|
||||
// arguments), also we set the timeout to -1 to block indefinitely until an
|
||||
// event occurs.
|
||||
int ret = poll(&p, 1, -1);
|
||||
if (ret < 1) return false;
|
||||
return p.revents & POLLOUT;
|
||||
}
|
||||
|
||||
} // namespace io::network
|
||||
|
@ -153,7 +153,41 @@ class Socket {
|
||||
* == 0 if the client closed the connection
|
||||
* < 0 if an error has occurred
|
||||
*/
|
||||
int Read(void *buffer, size_t len, bool nonblock = false);
|
||||
ssize_t Read(void *buffer, size_t len, bool nonblock = false);
|
||||
|
||||
/**
|
||||
* Wait until the socket becomes ready for a `Read` operation.
|
||||
* This function blocks indefinitely waiting for the socket to change its
|
||||
* state. This function is useful when you need a blocking operation on a
|
||||
* non-blocking socket, you can call this function to ensure that your next
|
||||
* `Read` operation will succeed.
|
||||
*
|
||||
* The function returns `true` if the wait succeded (there is data waiting to
|
||||
* be read from the socket) and returns `false` if the wait failed (the socket
|
||||
* was closed or something else bad happened).
|
||||
*
|
||||
* @return wait success status:
|
||||
* true if the wait succeeded
|
||||
* false if the wait failed
|
||||
*/
|
||||
bool WaitForReadyRead();
|
||||
|
||||
/**
|
||||
* Wait until the socket becomes ready for a `Write` operation.
|
||||
* This function blocks indefinitely waiting for the socket to change its
|
||||
* state. This function is useful when you need a blocking operation on a
|
||||
* non-blocking socket, you can call this function to ensure that your next
|
||||
* `Write` operation will succeed.
|
||||
*
|
||||
* The function returns `true` if the wait succeded (the socket can be written
|
||||
* to) and returns `false` if the wait failed (the socket was closed or
|
||||
* something else bad happened).
|
||||
*
|
||||
* @return wait success status:
|
||||
* true if the wait succeeded
|
||||
* false if the wait failed
|
||||
*/
|
||||
bool WaitForReadyWrite();
|
||||
|
||||
private:
|
||||
Socket(int fd, const Endpoint &endpoint) : socket_(fd), endpoint_(endpoint) {}
|
||||
|
@ -28,6 +28,7 @@ using communication::bolt::SessionData;
|
||||
using SessionT = communication::bolt::Session<communication::InputStream,
|
||||
communication::OutputStream>;
|
||||
using ServerT = communication::Server<SessionT, SessionData>;
|
||||
using communication::ServerContext;
|
||||
|
||||
// General purpose flags.
|
||||
DEFINE_string(interface, "0.0.0.0",
|
||||
@ -41,6 +42,8 @@ DEFINE_VALIDATED_int32(session_inactivity_timeout, 1800,
|
||||
"Time in seconds after which inactive sessions will be "
|
||||
"closed.",
|
||||
FLAG_IN_RANGE(1, INT32_MAX));
|
||||
DEFINE_string(cert_file, "", "Certificate file to use.");
|
||||
DEFINE_string(key_file, "", "Key file to use.");
|
||||
DEFINE_string(log_file, "", "Path to where the log should be stored.");
|
||||
DEFINE_HIDDEN_string(
|
||||
log_link_basename, "",
|
||||
@ -142,6 +145,9 @@ int WithInit(int argc, char **argv,
|
||||
stats::InitStatsLogging(get_stats_prefix());
|
||||
utils::OnScopeExit stop_stats([] { stats::StopStatsLogging(); });
|
||||
|
||||
// Initialize the communication library.
|
||||
communication::Init();
|
||||
|
||||
// Start memory warning logger.
|
||||
utils::Scheduler mem_log_scheduler;
|
||||
if (FLAGS_memory_warning_threshold > 0) {
|
||||
@ -160,9 +166,17 @@ void SingleNodeMain() {
|
||||
google::SetUsageMessage("Memgraph single-node database server");
|
||||
database::SingleNode db;
|
||||
SessionData session_data{db};
|
||||
|
||||
ServerContext context;
|
||||
std::string service_name = "Bolt";
|
||||
if (FLAGS_key_file != "" && FLAGS_cert_file != "") {
|
||||
context = ServerContext(FLAGS_key_file, FLAGS_cert_file);
|
||||
service_name = "BoltS";
|
||||
}
|
||||
|
||||
ServerT server({FLAGS_interface, static_cast<uint16_t>(FLAGS_port)},
|
||||
session_data, FLAGS_session_inactivity_timeout, "Bolt",
|
||||
FLAGS_num_workers);
|
||||
session_data, &context, FLAGS_session_inactivity_timeout,
|
||||
service_name, FLAGS_num_workers);
|
||||
|
||||
// Setup telemetry
|
||||
std::experimental::optional<telemetry::Telemetry> telemetry;
|
||||
@ -214,9 +228,17 @@ void MasterMain() {
|
||||
|
||||
database::Master db;
|
||||
SessionData session_data{db};
|
||||
|
||||
ServerContext context;
|
||||
std::string service_name = "Bolt";
|
||||
if (FLAGS_key_file != "" && FLAGS_cert_file != "") {
|
||||
context = ServerContext(FLAGS_key_file, FLAGS_cert_file);
|
||||
service_name = "BoltS";
|
||||
}
|
||||
|
||||
ServerT server({FLAGS_interface, static_cast<uint16_t>(FLAGS_port)},
|
||||
session_data, FLAGS_session_inactivity_timeout, "Bolt",
|
||||
FLAGS_num_workers);
|
||||
session_data, &context, FLAGS_session_inactivity_timeout,
|
||||
service_name, FLAGS_num_workers);
|
||||
|
||||
// Handler for regular termination signals
|
||||
auto shutdown = [&server] {
|
||||
|
@ -42,10 +42,11 @@ class TestSession {
|
||||
input_stream_.Shift(size + 2);
|
||||
}
|
||||
|
||||
communication::InputStream input_stream_;
|
||||
communication::OutputStream output_stream_;
|
||||
communication::InputStream &input_stream_;
|
||||
communication::OutputStream &output_stream_;
|
||||
};
|
||||
|
||||
using ContextT = communication::ServerContext;
|
||||
using ServerT = communication::Server<TestSession, TestData>;
|
||||
|
||||
void client_run(int num, const char *interface, uint16_t port,
|
||||
|
@ -32,8 +32,8 @@ class TestSession {
|
||||
output_stream_.Write(input_stream_.data(), input_stream_.size());
|
||||
}
|
||||
|
||||
communication::InputStream input_stream_;
|
||||
communication::OutputStream output_stream_;
|
||||
communication::InputStream &input_stream_;
|
||||
communication::OutputStream &output_stream_;
|
||||
};
|
||||
|
||||
std::atomic<bool> run{true};
|
||||
@ -63,8 +63,9 @@ TEST(Network, SocketReadHangOnConcurrentConnections) {
|
||||
TestData data;
|
||||
int N = (std::thread::hardware_concurrency() + 1) / 2;
|
||||
int Nc = N * 3;
|
||||
communication::Server<TestSession, TestData> server(endpoint, data, -1,
|
||||
"Test", N);
|
||||
communication::ServerContext context;
|
||||
communication::Server<TestSession, TestData> server(endpoint, data, &context,
|
||||
-1, "Test", N);
|
||||
|
||||
const auto &ep = server.endpoint();
|
||||
// start clients
|
||||
|
@ -21,7 +21,8 @@ TEST(Network, Server) {
|
||||
// initialize server
|
||||
TestData session_data;
|
||||
int N = (std::thread::hardware_concurrency() + 1) / 2;
|
||||
ServerT server(endpoint, session_data, -1, "Test", N);
|
||||
ContextT context;
|
||||
ServerT server(endpoint, session_data, &context, -1, "Test", N);
|
||||
|
||||
const auto &ep = server.endpoint();
|
||||
// start clients
|
||||
|
@ -22,7 +22,8 @@ TEST(Network, SessionLeak) {
|
||||
|
||||
// initialize server
|
||||
TestData session_data;
|
||||
ServerT server(endpoint, session_data, -1, "Test", 2);
|
||||
ContextT context;
|
||||
ServerT server(endpoint, session_data, &context, -1, "Test", 2);
|
||||
|
||||
// start clients
|
||||
int N = 50;
|
||||
|
@ -1,2 +1,5 @@
|
||||
# telemetry test binaries
|
||||
add_subdirectory(telemetry)
|
||||
|
||||
# ssl test binaries
|
||||
add_subdirectory(ssl)
|
||||
|
@ -6,3 +6,11 @@
|
||||
- server.py # server script
|
||||
- ../../../build_debug/tests/integration/telemetry/client # client binary
|
||||
- ../../../build_debug/tests/manual/kvstore_console # kvstore console binary
|
||||
|
||||
- name: integration__ssl
|
||||
cd: ssl
|
||||
commands: ./runner.sh
|
||||
infiles:
|
||||
- runner.sh # runner script
|
||||
- ../../../build_debug/tests/integration/ssl/tester # tester binary
|
||||
enable_network: true
|
||||
|
6
tests/integration/ssl/CMakeLists.txt
Normal file
6
tests/integration/ssl/CMakeLists.txt
Normal file
@ -0,0 +1,6 @@
|
||||
set(target_name memgraph__integration__ssl)
|
||||
set(tester_target_name ${target_name}__tester)
|
||||
|
||||
add_executable(${tester_target_name} tester.cpp)
|
||||
set_target_properties(${tester_target_name} PROPERTIES OUTPUT_NAME tester)
|
||||
target_link_libraries(${tester_target_name} memgraph_lib kvstore_dummy_lib)
|
78
tests/integration/ssl/runner.sh
Executable file
78
tests/integration/ssl/runner.sh
Executable file
@ -0,0 +1,78 @@
|
||||
#!/bin/bash -e
|
||||
|
||||
pushd () { command pushd "$@" > /dev/null; }
|
||||
popd () { command popd "$@" > /dev/null; }
|
||||
die () { printf "\033[1;31m~~ Test failed! ~~\033[0m\n\n"; exit 1; }
|
||||
|
||||
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
|
||||
cd "$DIR"
|
||||
|
||||
tmpdir=/tmp/memgraph_integration_ssl
|
||||
|
||||
if [ -d $tmpdir ]; then
|
||||
rm -rf $tmpdir
|
||||
fi
|
||||
mkdir -p $tmpdir
|
||||
cd $tmpdir
|
||||
|
||||
easyrsa="EasyRSA-3.0.4"
|
||||
wget -nv http://deps.memgraph.io/$easyrsa.tgz
|
||||
|
||||
tar -xf $easyrsa.tgz
|
||||
mv $easyrsa ca1
|
||||
cp -r ca1 ca2
|
||||
|
||||
for i in ca1 ca2; do
|
||||
pushd $i
|
||||
./easyrsa --batch init-pki
|
||||
./easyrsa --batch build-ca nopass
|
||||
./easyrsa --batch build-server-full server nopass
|
||||
./easyrsa --batch build-client-full client nopass
|
||||
popd
|
||||
done
|
||||
|
||||
binary_dir="$DIR/../../../build"
|
||||
if [ ! -d $binary_dir ]; then
|
||||
binary_dir="$DIR/../../../build_debug"
|
||||
fi
|
||||
|
||||
set +e
|
||||
|
||||
echo
|
||||
for i in ca1 ca2; do
|
||||
for j in none ca1 ca2; do
|
||||
for k in false true; do
|
||||
printf "\033[1;36m~~ Server CA: $i; Client CA: $j; Verify peer: $k ~~\033[0m\n"
|
||||
|
||||
if [ "$j" == "none" ]; then
|
||||
client_key=""
|
||||
client_cert=""
|
||||
else
|
||||
client_key=$j/pki/private/client.key
|
||||
client_cert=$j/pki/issued/client.crt
|
||||
fi
|
||||
|
||||
$binary_dir/tests/integration/ssl/tester \
|
||||
--server-key-file=$i/pki/private/server.key \
|
||||
--server-cert-file=$i/pki/issued/server.crt \
|
||||
--server-ca-file=$i/pki/ca.crt \
|
||||
--server-verify-peer=$k \
|
||||
--client-key-file=$client_key \
|
||||
--client-cert-file=$client_cert
|
||||
|
||||
exitcode=$?
|
||||
|
||||
if [ "$i" == "$j" ]; then
|
||||
[ $exitcode -eq 0 ] || die
|
||||
else
|
||||
if $k; then
|
||||
[ $exitcode -ne 0 ] || die
|
||||
else
|
||||
[ $exitcode -eq 0 ] || die
|
||||
fi
|
||||
fi
|
||||
|
||||
printf "\033[1;32m~~ Test ok! ~~\033[0m\n\n"
|
||||
done
|
||||
done
|
||||
done
|
76
tests/integration/ssl/tester.cpp
Normal file
76
tests/integration/ssl/tester.cpp
Normal file
@ -0,0 +1,76 @@
|
||||
#include <atomic>
|
||||
|
||||
#include <gflags/gflags.h>
|
||||
#include <glog/logging.h>
|
||||
|
||||
#include "communication/client.hpp"
|
||||
#include "communication/server.hpp"
|
||||
#include "utils/exceptions.hpp"
|
||||
|
||||
DEFINE_string(server_cert_file, "", "Server certificate file to use.");
|
||||
DEFINE_string(server_key_file, "", "Server key file to use.");
|
||||
DEFINE_string(server_ca_file, "", "Server CA file to use.");
|
||||
DEFINE_bool(server_verify_peer, false, "Set to true to verify the peer.");
|
||||
|
||||
DEFINE_string(client_cert_file, "", "Client certificate file to use.");
|
||||
DEFINE_string(client_key_file, "", "Client key file to use.");
|
||||
|
||||
const std::string message = "ssl echo test";
|
||||
|
||||
struct EchoData {};
|
||||
|
||||
class EchoSession {
|
||||
public:
|
||||
EchoSession(EchoData &, communication::InputStream &input_stream,
|
||||
communication::OutputStream &output_stream)
|
||||
: input_stream_(input_stream), output_stream_(output_stream) {}
|
||||
|
||||
void Execute() {
|
||||
if (input_stream_.size() < message.size()) return;
|
||||
LOG(INFO) << "Server received message.";
|
||||
if (!output_stream_.Write(input_stream_.data(), message.size())) {
|
||||
throw utils::BasicException("Output stream write failed!");
|
||||
}
|
||||
LOG(INFO) << "Server sent message.";
|
||||
input_stream_.Shift(message.size());
|
||||
}
|
||||
|
||||
private:
|
||||
communication::InputStream &input_stream_;
|
||||
communication::OutputStream &output_stream_;
|
||||
};
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
|
||||
// Initialize the communication stack.
|
||||
communication::Init();
|
||||
|
||||
// Initialize the server.
|
||||
EchoData echo_data;
|
||||
communication::ServerContext server_context(
|
||||
FLAGS_server_key_file, FLAGS_server_cert_file, FLAGS_server_ca_file,
|
||||
FLAGS_server_verify_peer);
|
||||
communication::Server<EchoSession, EchoData> server(
|
||||
{"127.0.0.1", 0}, echo_data, &server_context, -1, "SSL", 1);
|
||||
|
||||
// Initialize the client.
|
||||
communication::ClientContext client_context(FLAGS_client_key_file,
|
||||
FLAGS_client_cert_file);
|
||||
communication::Client client(&client_context);
|
||||
|
||||
// Connect to the server.
|
||||
CHECK(client.Connect(server.endpoint())) << "Couldn't connect to server!";
|
||||
|
||||
// Perform echo.
|
||||
CHECK(client.Write(message)) << "Client couldn't send message!";
|
||||
LOG(INFO) << "Client sent message.";
|
||||
CHECK(client.Read(message.size())) << "Client couldn't receive message!";
|
||||
LOG(INFO) << "Client received message.";
|
||||
CHECK(std::string(reinterpret_cast<const char *>(client.GetData()),
|
||||
message.size()) == message)
|
||||
<< "Received message isn't equal to sent message!";
|
||||
|
||||
return 0;
|
||||
}
|
@ -10,6 +10,7 @@
|
||||
#include "io/network/endpoint.hpp"
|
||||
|
||||
using EndpointT = io::network::Endpoint;
|
||||
using ContextT = communication::ClientContext;
|
||||
using ClientT = communication::bolt::Client;
|
||||
using QueryDataT = communication::bolt::QueryData;
|
||||
using communication::bolt::DecodedValue;
|
||||
@ -18,23 +19,23 @@ class BoltClient {
|
||||
public:
|
||||
BoltClient(const std::string &address, uint16_t port,
|
||||
const std::string &username, const std::string &password,
|
||||
const std::string & = "") {
|
||||
const std::string & = "", bool use_ssl = false)
|
||||
: context_(use_ssl), client_(context_) {
|
||||
EndpointT endpoint(address, port);
|
||||
client_ = std::make_unique<ClientT>();
|
||||
|
||||
if (!client_->Connect(endpoint, username, password)) {
|
||||
if (!client_.Connect(endpoint, username, password)) {
|
||||
LOG(FATAL) << "Could not connect to: " << endpoint;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
QueryDataT Execute(const std::string &query,
|
||||
const std::map<std::string, DecodedValue> ¶meters) {
|
||||
return client_->Execute(query, parameters);
|
||||
return client_.Execute(query, parameters);
|
||||
}
|
||||
|
||||
void Close() { client_->Close(); }
|
||||
void Close() { client_.Close(); }
|
||||
|
||||
private:
|
||||
std::unique_ptr<ClientT> client_;
|
||||
ContextT context_;
|
||||
ClientT client_;
|
||||
};
|
||||
|
@ -331,11 +331,14 @@ int main(int argc, char **argv) {
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
|
||||
communication::Init();
|
||||
|
||||
stats::InitStatsLogging(
|
||||
fmt::format("client.long_running.{}.{}", FLAGS_group, FLAGS_scenario));
|
||||
|
||||
Endpoint endpoint(FLAGS_address, FLAGS_port);
|
||||
Client client;
|
||||
ClientContext context(FLAGS_use_ssl);
|
||||
Client client(&context);
|
||||
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) {
|
||||
LOG(FATAL) << "Couldn't connect to " << endpoint;
|
||||
}
|
||||
|
@ -11,6 +11,7 @@
|
||||
#include "utils/exceptions.hpp"
|
||||
#include "utils/timer.hpp"
|
||||
|
||||
using communication::ClientContext;
|
||||
using communication::bolt::Client;
|
||||
using communication::bolt::DecodedValue;
|
||||
using io::network::Endpoint;
|
||||
|
@ -16,6 +16,7 @@ DEFINE_int32(num_workers, 1, "Number of workers");
|
||||
DEFINE_string(output, "", "Output file");
|
||||
DEFINE_string(username, "", "Username for the database");
|
||||
DEFINE_string(password, "", "Password for the database");
|
||||
DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server.");
|
||||
DEFINE_int32(duration, 30, "Number of seconds to execute benchmark");
|
||||
|
||||
DEFINE_string(group, "unknown", "Test group name");
|
||||
@ -97,7 +98,8 @@ class TestClient {
|
||||
std::thread runner_thread_;
|
||||
|
||||
private:
|
||||
Client client_;
|
||||
communication::ClientContext context_{FLAGS_use_ssl};
|
||||
Client client_{&context_};
|
||||
};
|
||||
|
||||
void RunMultithreadedTest(std::vector<std::unique_ptr<TestClient>> &clients) {
|
||||
|
@ -271,12 +271,15 @@ int main(int argc, char **argv) {
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
|
||||
communication::Init();
|
||||
|
||||
nlohmann::json config;
|
||||
std::cin >> config;
|
||||
|
||||
auto independent_nodes_ids = [&] {
|
||||
Endpoint endpoint(io::network::ResolveHostname(FLAGS_address), FLAGS_port);
|
||||
Client client;
|
||||
ClientContext context(FLAGS_use_ssl);
|
||||
Client client(&context);
|
||||
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) {
|
||||
LOG(FATAL) << "Couldn't connect to " << endpoint;
|
||||
}
|
||||
|
@ -19,6 +19,7 @@ DEFINE_string(address, "127.0.0.1", "Server address");
|
||||
DEFINE_int32(port, 7687, "Server port");
|
||||
DEFINE_string(username, "", "Username for the database");
|
||||
DEFINE_string(password, "", "Password for the database");
|
||||
DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server.");
|
||||
|
||||
using communication::bolt::DecodedValue;
|
||||
|
||||
@ -58,7 +59,8 @@ void ExecuteQueries(const std::vector<std::string> &queries,
|
||||
for (int i = 0; i < FLAGS_num_workers; ++i) {
|
||||
threads.push_back(std::thread([&]() {
|
||||
Endpoint endpoint(FLAGS_address, FLAGS_port);
|
||||
Client client;
|
||||
ClientContext context(FLAGS_use_ssl);
|
||||
Client client(&context);
|
||||
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) {
|
||||
LOG(FATAL) << "Couldn't connect to " << endpoint;
|
||||
}
|
||||
@ -100,6 +102,8 @@ int main(int argc, char **argv) {
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
|
||||
communication::Init();
|
||||
|
||||
std::ifstream ifile;
|
||||
std::istream *istream{&std::cin};
|
||||
|
||||
|
@ -71,6 +71,12 @@ target_link_libraries(${test_prefix}sl_position_and_count memgraph_lib kvstore_d
|
||||
add_manual_test(stripped_timing.cpp)
|
||||
target_link_libraries(${test_prefix}stripped_timing memgraph_lib kvstore_dummy_lib)
|
||||
|
||||
add_manual_test(ssl_client.cpp)
|
||||
target_link_libraries(${test_prefix}ssl_client memgraph_lib kvstore_dummy_lib)
|
||||
|
||||
add_manual_test(ssl_server.cpp)
|
||||
target_link_libraries(${test_prefix}ssl_server memgraph_lib kvstore_dummy_lib)
|
||||
|
||||
add_manual_test(xorshift.cpp)
|
||||
target_link_libraries(${test_prefix}xorshift mg-utils)
|
||||
|
||||
|
@ -10,15 +10,20 @@ DEFINE_string(address, "127.0.0.1", "Server address");
|
||||
DEFINE_int32(port, 7687, "Server port");
|
||||
DEFINE_string(username, "", "Username for the database");
|
||||
DEFINE_string(password, "", "Password for the database");
|
||||
DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server.");
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
|
||||
communication::Init();
|
||||
|
||||
// TODO: handle endpoint exception
|
||||
io::network::Endpoint endpoint(io::network::ResolveHostname(FLAGS_address),
|
||||
FLAGS_port);
|
||||
communication::bolt::Client client;
|
||||
|
||||
communication::ClientContext context(FLAGS_use_ssl);
|
||||
communication::bolt::Client client(&context);
|
||||
|
||||
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) return 1;
|
||||
|
||||
|
65
tests/manual/ssl_client.cpp
Normal file
65
tests/manual/ssl_client.cpp
Normal file
@ -0,0 +1,65 @@
|
||||
#include <gflags/gflags.h>
|
||||
#include <glog/logging.h>
|
||||
|
||||
#include "communication/client.hpp"
|
||||
#include "io/network/endpoint.hpp"
|
||||
#include "utils/timer.hpp"
|
||||
|
||||
DEFINE_string(address, "127.0.0.1", "Server address");
|
||||
DEFINE_int32(port, 54321, "Server port");
|
||||
DEFINE_string(cert_file, "", "Certificate file to use.");
|
||||
DEFINE_string(key_file, "", "Key file to use.");
|
||||
|
||||
bool EchoMessage(communication::Client &client, const std::string &data) {
|
||||
uint16_t size = data.size();
|
||||
if (!client.Write(reinterpret_cast<const uint8_t *>(&size), sizeof(size))) {
|
||||
LOG(WARNING) << "Couldn't send data size!";
|
||||
return false;
|
||||
}
|
||||
if (!client.Write(data)) {
|
||||
LOG(WARNING) << "Couldn't send data!";
|
||||
return false;
|
||||
}
|
||||
|
||||
client.ClearData();
|
||||
if (!client.Read(size)) {
|
||||
LOG(WARNING) << "Couldn't receive data!";
|
||||
return false;
|
||||
}
|
||||
if (std::string(reinterpret_cast<const char *>(client.GetData()), size) !=
|
||||
data) {
|
||||
LOG(WARNING) << "Received data isn't equal to sent data!";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
|
||||
communication::Init();
|
||||
|
||||
io::network::Endpoint endpoint(FLAGS_address, FLAGS_port);
|
||||
|
||||
communication::ClientContext context(FLAGS_key_file, FLAGS_cert_file);
|
||||
communication::Client client(&context);
|
||||
|
||||
if (!client.Connect(endpoint)) return 1;
|
||||
|
||||
bool success = true;
|
||||
while (true) {
|
||||
std::string s;
|
||||
std::getline(std::cin, s);
|
||||
if (s == "") break;
|
||||
if (!EchoMessage(client, s)) {
|
||||
success = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Send server shutdown signal. The call will fail, we don't care.
|
||||
EchoMessage(client, "");
|
||||
|
||||
return success ? 0 : 1;
|
||||
}
|
73
tests/manual/ssl_server.cpp
Normal file
73
tests/manual/ssl_server.cpp
Normal file
@ -0,0 +1,73 @@
|
||||
#include <atomic>
|
||||
|
||||
#include <gflags/gflags.h>
|
||||
#include <glog/logging.h>
|
||||
|
||||
#include "communication/server.hpp"
|
||||
#include "utils/exceptions.hpp"
|
||||
|
||||
DEFINE_string(address, "127.0.0.1", "Server address");
|
||||
DEFINE_int32(port, 54321, "Server port");
|
||||
DEFINE_string(cert_file, "", "Certificate file to use.");
|
||||
DEFINE_string(key_file, "", "Key file to use.");
|
||||
DEFINE_string(ca_file, "", "CA file to use.");
|
||||
DEFINE_bool(verify_peer, false, "Set to true to verify the peer.");
|
||||
|
||||
struct EchoData {
|
||||
std::atomic<bool> alive{true};
|
||||
};
|
||||
|
||||
class EchoSession {
|
||||
public:
|
||||
EchoSession(EchoData &data, communication::InputStream &input_stream,
|
||||
communication::OutputStream &output_stream)
|
||||
: data_(data),
|
||||
input_stream_(input_stream),
|
||||
output_stream_(output_stream) {}
|
||||
|
||||
void Execute() {
|
||||
if (input_stream_.size() < 2) return;
|
||||
const uint8_t *data = input_stream_.data();
|
||||
uint16_t size = *reinterpret_cast<const uint16_t *>(input_stream_.data());
|
||||
input_stream_.Resize(size + 2);
|
||||
if (input_stream_.size() < size + 2) return;
|
||||
if (size == 0) {
|
||||
LOG(INFO) << "Server received EOF message";
|
||||
data_.alive.store(false);
|
||||
return;
|
||||
}
|
||||
LOG(INFO) << "Server received '"
|
||||
<< std::string(reinterpret_cast<const char *>(data + 2), size)
|
||||
<< "'";
|
||||
if (!output_stream_.Write(data + 2, size)) {
|
||||
throw utils::BasicException("Output stream write failed!");
|
||||
}
|
||||
input_stream_.Shift(size + 2);
|
||||
}
|
||||
|
||||
private:
|
||||
EchoData &data_;
|
||||
communication::InputStream &input_stream_;
|
||||
communication::OutputStream &output_stream_;
|
||||
};
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
|
||||
communication::Init();
|
||||
|
||||
// Initialize the server.
|
||||
EchoData echo_data;
|
||||
io::network::Endpoint endpoint(FLAGS_address, FLAGS_port);
|
||||
communication::ServerContext context(FLAGS_key_file, FLAGS_cert_file,
|
||||
FLAGS_ca_file, FLAGS_verify_peer);
|
||||
communication::Server<EchoSession, EchoData> server(endpoint, echo_data,
|
||||
&context, -1, "SSL", 1);
|
||||
|
||||
while (echo_data.alive) {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
1
tests/stress/.gitignore
vendored
1
tests/stress/.gitignore
vendored
@ -1 +1,2 @@
|
||||
.long_running_stats
|
||||
*.pem
|
||||
|
@ -10,6 +10,10 @@
|
||||
commands: TIMEOUT=600 ./continuous_integration --properties-on-disk
|
||||
infiles: *STRESS_INFILES
|
||||
|
||||
- name: stress_ssl
|
||||
commands: TIMEOUT=600 ./continuous_integration --use-ssl
|
||||
infiles: *STRESS_INFILES
|
||||
|
||||
- name: stress_large
|
||||
project: release
|
||||
commands: TIMEOUT=43200 ./continuous_integration --large-dataset
|
||||
|
@ -146,14 +146,14 @@ def connection_argument_parser():
|
||||
'''
|
||||
parser = ArgumentParser(description=__doc__)
|
||||
|
||||
parser.add_argument('--endpoint', type=str, default='localhost:7687',
|
||||
parser.add_argument('--endpoint', type=str, default='127.0.0.1:7687',
|
||||
help='DBMS instance endpoint. '
|
||||
'Bolt protocol is the only option.')
|
||||
parser.add_argument('--username', type=str, default='neo4j',
|
||||
help='DBMS instance username.')
|
||||
parser.add_argument('--password', type=int, default='1234',
|
||||
help='DBMS instance password.')
|
||||
parser.add_argument('--ssl-enabled', action='store_true',
|
||||
parser.add_argument('--use-ssl', action='store_true',
|
||||
help="Is SSL enabled?")
|
||||
return parser
|
||||
|
||||
@ -163,7 +163,7 @@ def bolt_session(url, auth, ssl=False):
|
||||
'''
|
||||
with wrapper around Bolt session.
|
||||
|
||||
:param url: str, e.g. "bolt://localhost:7687"
|
||||
:param url: str, e.g. "bolt://127.0.0.1:7687"
|
||||
:param auth: auth method, goes directly to the Bolt driver constructor
|
||||
:param ssl: bool, is ssl enabled
|
||||
'''
|
||||
@ -183,14 +183,15 @@ def argument_session(args):
|
||||
:return: Bolt session context manager based on program arguments
|
||||
'''
|
||||
return bolt_session('bolt://' + args.endpoint,
|
||||
(args.username, str(args.password)))
|
||||
(args.username, str(args.password)),
|
||||
args.use_ssl)
|
||||
|
||||
|
||||
def argument_driver(args, ssl=False):
|
||||
def argument_driver(args):
|
||||
return GraphDatabase.driver(
|
||||
'bolt://' + args.endpoint,
|
||||
auth=(args.username, str(args.password)),
|
||||
encrypted=ssl)
|
||||
encrypted=args.use_ssl)
|
||||
|
||||
# This class is used to create and cache sessions. Session is cached by args
|
||||
# used to create it and process' pid in which it was created. This makes it easy
|
||||
|
@ -73,6 +73,8 @@ BASE_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", ".."))
|
||||
BUILD_DIR = os.path.join(BASE_DIR, "build")
|
||||
CONFIG_DIR = os.path.join(BASE_DIR, "config")
|
||||
MEASUREMENTS_FILE = os.path.join(SCRIPT_DIR, ".apollo_measurements")
|
||||
KEY_FILE = os.path.join(SCRIPT_DIR, ".key.pem")
|
||||
CERT_FILE = os.path.join(SCRIPT_DIR, ".cert.pem")
|
||||
|
||||
# long running stats file
|
||||
STATS_FILE = os.path.join(SCRIPT_DIR, ".long_running_stats")
|
||||
@ -132,6 +134,8 @@ parser.add_argument("--python", default = os.path.join(SCRIPT_DIR,
|
||||
"ve3", "bin", "python3"), type = str)
|
||||
parser.add_argument("--large-dataset", action = "store_const",
|
||||
const = True, default = False)
|
||||
parser.add_argument("--use-ssl", action = "store_const",
|
||||
const = True, default = False)
|
||||
parser.add_argument("--verbose", action = "store_const",
|
||||
const = True, default = False)
|
||||
args = parser.parse_args()
|
||||
@ -140,6 +144,14 @@ args = parser.parse_args()
|
||||
if not os.path.exists(args.memgraph):
|
||||
args.memgraph = os.path.join(BASE_DIR, "build_release", "memgraph")
|
||||
|
||||
# generate temporary SSL certs
|
||||
if args.use_ssl:
|
||||
# https://unix.stackexchange.com/questions/104171/create-ssl-certificate-non-interactively
|
||||
subj = "/C=HR/ST=Zagreb/L=Zagreb/O=Memgraph/CN=db.memgraph.com"
|
||||
subprocess.run(["openssl", "req", "-new", "-newkey", "rsa:4096",
|
||||
"-days", "365", "-nodes", "-x509", "-subj", subj,
|
||||
"-keyout", KEY_FILE, "-out", CERT_FILE], check=True)
|
||||
|
||||
# start memgraph
|
||||
cwd = os.path.dirname(args.memgraph)
|
||||
cmd = [args.memgraph, "--num-workers=" + str(THREADS)]
|
||||
@ -151,6 +163,8 @@ if args.durability_directory:
|
||||
cmd += ["--durability-directory", args.durability_directory]
|
||||
if args.properties_on_disk:
|
||||
cmd += ["--properties-on-disk", "id,x"]
|
||||
if args.use_ssl:
|
||||
cmd += ["--cert-file", CERT_FILE, "--key-file", KEY_FILE]
|
||||
proc_mg = subprocess.Popen(cmd, cwd = cwd,
|
||||
env = {"MEMGRAPH_CONFIG": args.config})
|
||||
time.sleep(1.0)
|
||||
@ -167,6 +181,8 @@ def cleanup():
|
||||
runtimes = {}
|
||||
dataset = LARGE_DATASET if args.large_dataset else SMALL_DATASET
|
||||
for test in dataset:
|
||||
if args.use_ssl:
|
||||
test["options"] += ["--use-ssl"]
|
||||
runtime = run_test(args, **test)
|
||||
runtimes[os.path.splitext(test["test"])[0]] = runtime
|
||||
|
||||
@ -176,6 +192,11 @@ ret_mg = proc_mg.wait()
|
||||
if ret_mg != 0:
|
||||
raise Exception("Memgraph binary returned non-zero ({})!".format(ret_mg))
|
||||
|
||||
# cleanup certificates
|
||||
if args.use_ssl:
|
||||
os.remove(KEY_FILE)
|
||||
os.remove(CERT_FILE)
|
||||
|
||||
# measurements
|
||||
measurements = ""
|
||||
for key, value in runtimes.items():
|
||||
|
@ -8,6 +8,7 @@
|
||||
#include "utils/timer.hpp"
|
||||
|
||||
using EndpointT = io::network::Endpoint;
|
||||
using ClientContextT = communication::ClientContext;
|
||||
using ClientT = communication::bolt::Client;
|
||||
using DecodedValueT = communication::bolt::DecodedValue;
|
||||
using QueryDataT = communication::bolt::QueryData;
|
||||
@ -17,6 +18,7 @@ DEFINE_string(address, "127.0.0.1", "Server address");
|
||||
DEFINE_int32(port, 7687, "Server port");
|
||||
DEFINE_string(username, "", "Username for the database");
|
||||
DEFINE_string(password, "", "Password for the database");
|
||||
DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server.");
|
||||
|
||||
DEFINE_int32(vertex_count, 0,
|
||||
"The average number of vertices in the graph per worker");
|
||||
@ -51,7 +53,7 @@ class GraphSession {
|
||||
}
|
||||
|
||||
EndpointT endpoint(FLAGS_address, FLAGS_port);
|
||||
client_ = std::make_unique<ClientT>();
|
||||
client_ = std::make_unique<ClientT>(&context_);
|
||||
|
||||
if (!client_->Connect(endpoint, FLAGS_username, FLAGS_password)) {
|
||||
throw utils::BasicException("Couldn't connect to server!");
|
||||
@ -60,6 +62,7 @@ class GraphSession {
|
||||
|
||||
private:
|
||||
uint64_t id_;
|
||||
ClientContextT context_{FLAGS_use_ssl};
|
||||
std::unique_ptr<ClientT> client_;
|
||||
|
||||
std::set<uint64_t> vertices_;
|
||||
@ -362,6 +365,8 @@ int main(int argc, char **argv) {
|
||||
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
|
||||
communication::Init();
|
||||
|
||||
CHECK(FLAGS_vertex_count > 0) << "Vertex count must be greater than 0!";
|
||||
CHECK(FLAGS_edge_count > 0) << "Edge count must be greater than 0!";
|
||||
|
||||
@ -369,7 +374,8 @@ int main(int argc, char **argv) {
|
||||
|
||||
// create client
|
||||
EndpointT endpoint(FLAGS_address, FLAGS_port);
|
||||
ClientT client;
|
||||
ClientContextT context(FLAGS_use_ssl);
|
||||
ClientT client(&context);
|
||||
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) {
|
||||
throw utils::BasicException("Couldn't connect to server!");
|
||||
}
|
||||
|
@ -52,8 +52,9 @@ bool QueryServer(io::network::Socket &socket) {
|
||||
TEST(NetworkTimeouts, InactiveSession) {
|
||||
// Instantiate the server and set the session timeout to 2 seconds.
|
||||
TestData test_data;
|
||||
communication::ServerContext context;
|
||||
communication::Server<TestSession, TestData> server{
|
||||
{"127.0.0.1", 0}, test_data, 2, "Test", 1};
|
||||
{"127.0.0.1", 0}, test_data, &context, 2, "Test", 1};
|
||||
|
||||
// Create the client and connect to the server.
|
||||
io::network::Socket client;
|
||||
|
94
tests/unit/socket.cpp
Normal file
94
tests/unit/socket.cpp
Normal file
@ -0,0 +1,94 @@
|
||||
#include <chrono>
|
||||
#include <csignal>
|
||||
#include <thread>
|
||||
|
||||
#include <glog/logging.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "io/network/socket.hpp"
|
||||
#include "utils/timer.hpp"
|
||||
|
||||
TEST(Socket, WaitForReadyRead) {
|
||||
io::network::Socket server;
|
||||
ASSERT_TRUE(server.Bind({"127.0.0.1", 0}));
|
||||
ASSERT_TRUE(server.Listen(1024));
|
||||
|
||||
std::thread thread([&server] {
|
||||
io::network::Socket client;
|
||||
ASSERT_TRUE(client.Connect(server.endpoint()));
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(2000));
|
||||
ASSERT_TRUE(client.Write("test"));
|
||||
});
|
||||
|
||||
uint8_t buff[100];
|
||||
auto client = server.Accept();
|
||||
ASSERT_TRUE(client);
|
||||
|
||||
client->SetNonBlocking();
|
||||
|
||||
ASSERT_EQ(client->Read(buff, sizeof(buff)), -1);
|
||||
ASSERT_TRUE(errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR);
|
||||
|
||||
utils::Timer timer;
|
||||
ASSERT_TRUE(client->WaitForReadyRead());
|
||||
ASSERT_GT(timer.Elapsed().count(), 1.0);
|
||||
|
||||
ASSERT_GT(client->Read(buff, sizeof(buff)), 0);
|
||||
|
||||
thread.join();
|
||||
}
|
||||
|
||||
TEST(Socket, WaitForReadyWrite) {
|
||||
io::network::Socket server;
|
||||
ASSERT_TRUE(server.Bind({"127.0.0.1", 0}));
|
||||
ASSERT_TRUE(server.Listen(1024));
|
||||
|
||||
std::thread thread([&server] {
|
||||
uint8_t buff[10000];
|
||||
io::network::Socket client;
|
||||
ASSERT_TRUE(client.Connect(server.endpoint()));
|
||||
client.SetNonBlocking();
|
||||
|
||||
// Decrease the TCP read buffer.
|
||||
int len = 1024;
|
||||
ASSERT_EQ(setsockopt(client.fd(), SOL_SOCKET, SO_RCVBUF, &len, sizeof(len)),
|
||||
0);
|
||||
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(2000));
|
||||
while (true) {
|
||||
int ret = client.Read(buff, sizeof(buff));
|
||||
if (ret == -1 && errno != EAGAIN && errno != EWOULDBLOCK &&
|
||||
errno != EINTR) {
|
||||
std::raise(SIGPIPE);
|
||||
} else if (ret == 0) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
auto client = server.Accept();
|
||||
ASSERT_TRUE(client);
|
||||
|
||||
client->SetNonBlocking();
|
||||
|
||||
// Decrease the TCP write buffer.
|
||||
int len = 1024;
|
||||
ASSERT_EQ(setsockopt(client->fd(), SOL_SOCKET, SO_SNDBUF, &len, sizeof(len)),
|
||||
0);
|
||||
|
||||
utils::Timer timer;
|
||||
for (int i = 0; i < 1000000; ++i) {
|
||||
ASSERT_TRUE(client->Write("test"));
|
||||
}
|
||||
ASSERT_GT(timer.Elapsed().count(), 1.0);
|
||||
|
||||
client->Close();
|
||||
|
||||
thread.join();
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
Loading…
Reference in New Issue
Block a user