Implement SSL support for servers and clients
Summary: This diff implements OpenSSL support in the network stack. Currently SSL support is only enabled for Bolt connections, support for RPC connections will be added in another diff. Reviewers: buda, teon.banek Reviewed By: buda Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D1328
This commit is contained in:
parent
44821a918c
commit
1d448d40ca
@ -10,6 +10,7 @@
|
|||||||
* Static vertices/edges id generators exposed through the Id Cypher function.
|
* Static vertices/edges id generators exposed through the Id Cypher function.
|
||||||
* Properties on disk added.
|
* Properties on disk added.
|
||||||
* Telemetry added.
|
* Telemetry added.
|
||||||
|
* SSL support added.
|
||||||
* Add `toString` function to openCypher
|
* Add `toString` function to openCypher
|
||||||
|
|
||||||
### Bug Fixes and Other Changes
|
### Bug Fixes and Other Changes
|
||||||
|
@ -140,6 +140,9 @@ endif()
|
|||||||
set(Boost_USE_STATIC_LIBS ON)
|
set(Boost_USE_STATIC_LIBS ON)
|
||||||
find_package(Boost 1.62 REQUIRED COMPONENTS iostreams serialization)
|
find_package(Boost 1.62 REQUIRED COMPONENTS iostreams serialization)
|
||||||
|
|
||||||
|
# OpenSSL
|
||||||
|
find_package(OpenSSL REQUIRED)
|
||||||
|
|
||||||
set(libs_dir ${CMAKE_SOURCE_DIR}/libs)
|
set(libs_dir ${CMAKE_SOURCE_DIR}/libs)
|
||||||
add_subdirectory(libs EXCLUDE_FROM_ALL)
|
add_subdirectory(libs EXCLUDE_FROM_ALL)
|
||||||
|
|
||||||
@ -320,6 +323,8 @@ set(CPACK_DEBIAN_PACKAGE_DESCRIPTION "${CPACK_PACKAGE_DESCRIPTION_SUMMARY}
|
|||||||
Contains Memgraph, the graph database. It aims to deliver developers the
|
Contains Memgraph, the graph database. It aims to deliver developers the
|
||||||
speed, simplicity and scale required to build the next generation of
|
speed, simplicity and scale required to build the next generation of
|
||||||
applications driver by real-time connected data.")
|
applications driver by real-time connected data.")
|
||||||
|
# Add `openssl` package to dependencies list. Used to generate SSL certificates.
|
||||||
|
set(CPACK_DEBIAN_memgraph_PACKAGE_DEPENDS "openssl (>= 1.1.0)")
|
||||||
|
|
||||||
# RPM specific
|
# RPM specific
|
||||||
set(CPACK_RPM_PACKAGE_URL https://memgraph.com)
|
set(CPACK_RPM_PACKAGE_URL https://memgraph.com)
|
||||||
@ -335,6 +340,8 @@ set(CPACK_RPM_USER_BINARY_SPECFILE "${CMAKE_SOURCE_DIR}/release/rpm/memgraph.spe
|
|||||||
set(CPACK_RPM_PACKAGE_DESCRIPTION "Contains Memgraph, the graph database.
|
set(CPACK_RPM_PACKAGE_DESCRIPTION "Contains Memgraph, the graph database.
|
||||||
It aims to deliver developers the speed, simplicity and scale required to build
|
It aims to deliver developers the speed, simplicity and scale required to build
|
||||||
the next generation of applications driver by real-time connected data.")
|
the next generation of applications driver by real-time connected data.")
|
||||||
|
# Add `openssl` package to dependencies list. Used to generate SSL certificates.
|
||||||
|
set(CPACK_RPM_memgraph_PACKAGE_REQUIRES "openssl >= 1.0.0")
|
||||||
|
|
||||||
# All variables must be set before including.
|
# All variables must be set before including.
|
||||||
include(CPack)
|
include(CPack)
|
||||||
|
@ -16,6 +16,12 @@
|
|||||||
# Port the server should listen on.
|
# Port the server should listen on.
|
||||||
--port=7687
|
--port=7687
|
||||||
|
|
||||||
|
# Path to a SSL certificate file that should be used.
|
||||||
|
--cert-file=/etc/memgraph/ssl/cert.pem
|
||||||
|
|
||||||
|
# Path to a SSL key file that should be used.
|
||||||
|
--key-file=/etc/memgraph/ssl/key.pem
|
||||||
|
|
||||||
# Number of workers used by the Memgraph server. By default, this will be the
|
# Number of workers used by the Memgraph server. By default, this will be the
|
||||||
# number of processing units available on the machine.
|
# number of processing units available on the machine.
|
||||||
# --num-workers=8
|
# --num-workers=8
|
||||||
|
@ -13,11 +13,9 @@ from neo4j.v1 import GraphDatabase, basic_auth
|
|||||||
|
|
||||||
# Initialize and configure the driver.
|
# Initialize and configure the driver.
|
||||||
# * provide the correct URL where Memgraph is reachable;
|
# * provide the correct URL where Memgraph is reachable;
|
||||||
# * use an empty user name and password, and
|
# * use an empty user name and password.
|
||||||
# * disable encryption (not supported).
|
|
||||||
driver = GraphDatabase.driver("bolt://localhost:7687",
|
driver = GraphDatabase.driver("bolt://localhost:7687",
|
||||||
auth=basic_auth("", ""),
|
auth=basic_auth("", ""))
|
||||||
encrypted=False)
|
|
||||||
|
|
||||||
# Start a session in which queries are executed.
|
# Start a session in which queries are executed.
|
||||||
session = driver.session()
|
session = driver.session()
|
||||||
@ -51,9 +49,7 @@ The details about Java driver can be found
|
|||||||
[on GitHub](https://github.com/neo4j/neo4j-java-driver).
|
[on GitHub](https://github.com/neo4j/neo4j-java-driver).
|
||||||
|
|
||||||
The example below is equivalent to Python example. Major difference is that
|
The example below is equivalent to Python example. Major difference is that
|
||||||
`Config` object has to be created before the driver construction. Encryption
|
`Config` object has to be created before the driver construction.
|
||||||
has to be disabled by calling `withoutEncryption` method against the `Config`
|
|
||||||
builder.
|
|
||||||
|
|
||||||
```java
|
```java
|
||||||
import org.neo4j.driver.v1.*;
|
import org.neo4j.driver.v1.*;
|
||||||
@ -64,7 +60,7 @@ import java.util.*;
|
|||||||
public class JavaQuickStart {
|
public class JavaQuickStart {
|
||||||
public static void main(String[] args) {
|
public static void main(String[] args) {
|
||||||
// Initialize driver.
|
// Initialize driver.
|
||||||
Config config = Config.build().withoutEncryption().toConfig();
|
Config config = Config.build().toConfig();
|
||||||
Driver driver = GraphDatabase.driver("bolt://localhost:7687",
|
Driver driver = GraphDatabase.driver("bolt://localhost:7687",
|
||||||
AuthTokens.basic("",""),
|
AuthTokens.basic("",""),
|
||||||
config);
|
config);
|
||||||
@ -93,9 +89,7 @@ public class JavaQuickStart {
|
|||||||
The details about Javascript driver can be found
|
The details about Javascript driver can be found
|
||||||
[on GitHub](https://github.com/neo4j/neo4j-javascript-driver).
|
[on GitHub](https://github.com/neo4j/neo4j-javascript-driver).
|
||||||
|
|
||||||
The Javascript example below is equivalent to Python and Java examples. SSL
|
The Javascript example below is equivalent to Python and Java examples.
|
||||||
can be disabled by passing `{encrypted: 'ENCRYPTION_OFF'}` during the driver
|
|
||||||
construction.
|
|
||||||
|
|
||||||
Here is an example related to `Node.js`. Memgraph doesn't have integrated
|
Here is an example related to `Node.js`. Memgraph doesn't have integrated
|
||||||
support for `WebSocket` which is required during the execution in any web
|
support for `WebSocket` which is required during the execution in any web
|
||||||
@ -109,8 +103,7 @@ proxy port.
|
|||||||
```javascript
|
```javascript
|
||||||
var neo4j = require('neo4j-driver').v1;
|
var neo4j = require('neo4j-driver').v1;
|
||||||
var driver = neo4j.driver("bolt://localhost:7687",
|
var driver = neo4j.driver("bolt://localhost:7687",
|
||||||
neo4j.auth.basic("neo4j", "1234"),
|
neo4j.auth.basic("neo4j", "1234"));
|
||||||
{encrypted: 'ENCRYPTION_OFF'});
|
|
||||||
var session = driver.session();
|
var session = driver.session();
|
||||||
|
|
||||||
function die() {
|
function die() {
|
||||||
@ -146,8 +139,7 @@ run_query("MATCH (n) DETACH DELETE n", function (result) {
|
|||||||
|
|
||||||
The C# driver is hosted
|
The C# driver is hosted
|
||||||
[on GitHub](https://github.com/neo4j/neo4j-dotnet-driver). The example below
|
[on GitHub](https://github.com/neo4j/neo4j-dotnet-driver). The example below
|
||||||
performs the same work as all of the previous examples. Encryption is disabled
|
performs the same work as all of the previous examples.
|
||||||
by setting `EncryptionLevel.NONE` on the `Config`.
|
|
||||||
|
|
||||||
```csh
|
```csh
|
||||||
using System;
|
using System;
|
||||||
@ -158,7 +150,6 @@ public class Basic {
|
|||||||
public static void Main(string[] args) {
|
public static void Main(string[] args) {
|
||||||
// Initialize the driver.
|
// Initialize the driver.
|
||||||
var config = Config.DefaultConfig;
|
var config = Config.DefaultConfig;
|
||||||
config.EncryptionLevel = EncryptionLevel.None;
|
|
||||||
using(var driver = GraphDatabase.Driver("bolt://localhost:7687", AuthTokens.None, config))
|
using(var driver = GraphDatabase.Driver("bolt://localhost:7687", AuthTokens.None, config))
|
||||||
using(var session = driver.Session())
|
using(var session = driver.Session())
|
||||||
{
|
{
|
||||||
@ -176,6 +167,18 @@ public class Basic {
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Secure Sockets Layer (SSL)
|
||||||
|
|
||||||
|
Secure connections are supported and enabled by default. The server initially
|
||||||
|
ships with a self-signed testing certificate. The certificate can be replaced
|
||||||
|
by editing the following parameters in `/etc/memgraph/memgraph.conf`:
|
||||||
|
```
|
||||||
|
--cert-file=/path/to/ssl/certificate.pem
|
||||||
|
--key-file=/path/to/ssl/privatekey.pem
|
||||||
|
```
|
||||||
|
To disable SSL support and use insecure connections to the database you should
|
||||||
|
set both parameters (`--cert-file` and `--key-file`) to empty values.
|
||||||
|
|
||||||
### Limitations
|
### Limitations
|
||||||
|
|
||||||
Memgraph is currently in early stage, and has a number of limitations we plan
|
Memgraph is currently in early stage, and has a number of limitations we plan
|
||||||
@ -186,9 +189,3 @@ to remove in future versions.
|
|||||||
Memgraph is currently single-user only. There is no way to control user
|
Memgraph is currently single-user only. There is no way to control user
|
||||||
privileges. The default user has read and write privileges over the whole
|
privileges. The default user has read and write privileges over the whole
|
||||||
database.
|
database.
|
||||||
|
|
||||||
#### Secure Sockets Layer (SSL)
|
|
||||||
|
|
||||||
Secure connections are not supported. For this reason each client
|
|
||||||
driver needs to be configured not to use encryption. Consult driver-specific
|
|
||||||
guides for details.
|
|
||||||
|
@ -183,7 +183,7 @@ After installing `neo4j-client`, connect to the running Memgraph instance by
|
|||||||
issuing the following shell command.
|
issuing the following shell command.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
neo4j-client --insecure -u "" -p "" localhost 7687
|
neo4j-client -u "" -p "" localhost 7687
|
||||||
```
|
```
|
||||||
|
|
||||||
After the client has started it should present a command prompt similar to:
|
After the client has started it should present a command prompt similar to:
|
||||||
@ -191,7 +191,7 @@ After the client has started it should present a command prompt similar to:
|
|||||||
```bash
|
```bash
|
||||||
neo4j-client 2.1.3
|
neo4j-client 2.1.3
|
||||||
Enter `:help` for usage hints.
|
Enter `:help` for usage hints.
|
||||||
Connected to 'neo4j://@localhost:7687' (insecure)
|
Connected to 'neo4j://@localhost:7687'
|
||||||
neo4j>
|
neo4j>
|
||||||
```
|
```
|
||||||
|
|
||||||
|
1
init
1
init
@ -8,6 +8,7 @@ required_pkgs=(git arcanist # source code control
|
|||||||
curl wget # for downloading libs
|
curl wget # for downloading libs
|
||||||
uuid-dev default-jre-headless # required by antlr
|
uuid-dev default-jre-headless # required by antlr
|
||||||
libreadline-dev # for memgraph console
|
libreadline-dev # for memgraph console
|
||||||
|
libssl-dev
|
||||||
libboost-iostreams-dev
|
libboost-iostreams-dev
|
||||||
libboost-serialization-dev
|
libboost-serialization-dev
|
||||||
python3 python-virtualenv python3-pip # for qa, macro_benchmark and stress tests
|
python3 python-virtualenv python3-pip # for qa, macro_benchmark and stress tests
|
||||||
|
@ -29,6 +29,16 @@ case "$1" in
|
|||||||
chmod 750 /var/log/memgraph || exit 1
|
chmod 750 /var/log/memgraph || exit 1
|
||||||
# Make examples directory immutable (optional)
|
# Make examples directory immutable (optional)
|
||||||
chattr +i -R /usr/share/memgraph/examples || true
|
chattr +i -R /usr/share/memgraph/examples || true
|
||||||
|
|
||||||
|
# Generate SSL certificates
|
||||||
|
if [ ! -d /etc/memgraph/ssl ]; then
|
||||||
|
mkdir /etc/memgraph/ssl || exit 1
|
||||||
|
openssl req -x509 -newkey rsa:4096 -days 3650 -nodes \
|
||||||
|
-keyout /etc/memgraph/ssl/key.pem -out /etc/memgraph/ssl/cert.pem \
|
||||||
|
-subj "/C=GB/ST=London/L=London/O=Memgraph Ltd./CN=Memgraph DB" || exit 1
|
||||||
|
chown memgraph:memgraph /etc/memgraph/ssl/* || exit 1
|
||||||
|
chmod 400 /etc/memgraph/ssl/* || exit 1
|
||||||
|
fi
|
||||||
;;
|
;;
|
||||||
|
|
||||||
abort-upgrade|abort-remove|abort-deconfigure)
|
abort-upgrade|abort-remove|abort-deconfigure)
|
||||||
|
@ -71,6 +71,16 @@ chown memgraph:adm /var/log/memgraph || exit 1
|
|||||||
chmod 750 /var/log/memgraph || exit 1
|
chmod 750 /var/log/memgraph || exit 1
|
||||||
# Make examples directory immutable (optional)
|
# Make examples directory immutable (optional)
|
||||||
chattr +i -R /usr/share/memgraph/examples || true
|
chattr +i -R /usr/share/memgraph/examples || true
|
||||||
|
|
||||||
|
# Generate SSL certificates
|
||||||
|
if [ ! -d /etc/memgraph/ssl ]; then
|
||||||
|
mkdir /etc/memgraph/ssl || exit 1
|
||||||
|
openssl req -x509 -newkey rsa:4096 -days 3650 -nodes \
|
||||||
|
-keyout /etc/memgraph/ssl/key.pem -out /etc/memgraph/ssl/cert.pem \
|
||||||
|
-subj "/C=GB/ST=London/L=London/O=Memgraph Ltd./CN=Memgraph DB" || exit 1
|
||||||
|
chown memgraph:memgraph /etc/memgraph/ssl/* || exit 1
|
||||||
|
chmod 400 /etc/memgraph/ssl/* || exit 1
|
||||||
|
fi
|
||||||
@RPM_SYMLINK_POSTINSTALL@
|
@RPM_SYMLINK_POSTINSTALL@
|
||||||
@CPACK_RPM_SPEC_POSTINSTALL@
|
@CPACK_RPM_SPEC_POSTINSTALL@
|
||||||
|
|
||||||
|
@ -8,6 +8,10 @@ add_subdirectory(telemetry)
|
|||||||
# all memgraph src files
|
# all memgraph src files
|
||||||
set(memgraph_src_files
|
set(memgraph_src_files
|
||||||
communication/buffer.cpp
|
communication/buffer.cpp
|
||||||
|
communication/client.cpp
|
||||||
|
communication/context.cpp
|
||||||
|
communication/helpers.cpp
|
||||||
|
communication/init.cpp
|
||||||
communication/bolt/v1/decoder/decoded_value.cpp
|
communication/bolt/v1/decoder/decoded_value.cpp
|
||||||
communication/rpc/client.cpp
|
communication/rpc/client.cpp
|
||||||
communication/rpc/protocol.cpp
|
communication/rpc/protocol.cpp
|
||||||
@ -189,6 +193,7 @@ string(TOLOWER ${CMAKE_BUILD_TYPE} lower_build_type)
|
|||||||
# memgraph_lib depend on these libraries
|
# memgraph_lib depend on these libraries
|
||||||
set(MEMGRAPH_ALL_LIBS stdc++fs Threads::Threads fmt cppitertools
|
set(MEMGRAPH_ALL_LIBS stdc++fs Threads::Threads fmt cppitertools
|
||||||
antlr_opencypher_parser_lib dl glog gflags capnp kj
|
antlr_opencypher_parser_lib dl glog gflags capnp kj
|
||||||
|
${OPENSSL_LIBRARIES}
|
||||||
${Boost_IOSTREAMS_LIBRARY_RELEASE}
|
${Boost_IOSTREAMS_LIBRARY_RELEASE}
|
||||||
${Boost_SERIALIZATION_LIBRARY_RELEASE}
|
${Boost_SERIALIZATION_LIBRARY_RELEASE}
|
||||||
mg-utils mg-io)
|
mg-utils mg-io)
|
||||||
@ -206,6 +211,7 @@ endif()
|
|||||||
# STATIC library used by memgraph executables
|
# STATIC library used by memgraph executables
|
||||||
add_library(memgraph_lib STATIC ${memgraph_src_files})
|
add_library(memgraph_lib STATIC ${memgraph_src_files})
|
||||||
target_link_libraries(memgraph_lib ${MEMGRAPH_ALL_LIBS})
|
target_link_libraries(memgraph_lib ${MEMGRAPH_ALL_LIBS})
|
||||||
|
target_include_directories(memgraph_lib PRIVATE ${OPENSSL_INCLUDE_DIR})
|
||||||
add_dependencies(memgraph_lib generate_opencypher_parser)
|
add_dependencies(memgraph_lib generate_opencypher_parser)
|
||||||
add_dependencies(memgraph_lib generate_lcp)
|
add_dependencies(memgraph_lib generate_lcp)
|
||||||
add_dependencies(memgraph_lib generate_capnp)
|
add_dependencies(memgraph_lib generate_capnp)
|
||||||
|
@ -32,9 +32,9 @@ struct QueryData {
|
|||||||
std::map<std::string, DecodedValue> metadata;
|
std::map<std::string, DecodedValue> metadata;
|
||||||
};
|
};
|
||||||
|
|
||||||
class Client {
|
class Client final {
|
||||||
public:
|
public:
|
||||||
Client() {}
|
explicit Client(communication::ClientContext *context) : client_(context) {}
|
||||||
|
|
||||||
Client(const Client &) = delete;
|
Client(const Client &) = delete;
|
||||||
Client(Client &&) = delete;
|
Client(Client &&) = delete;
|
||||||
|
@ -20,7 +20,7 @@ namespace communication {
|
|||||||
* stack where all execution when it is being done is being done on a single
|
* stack where all execution when it is being done is being done on a single
|
||||||
* thread.
|
* thread.
|
||||||
*/
|
*/
|
||||||
class Buffer {
|
class Buffer final {
|
||||||
private:
|
private:
|
||||||
// Initial capacity of the internal buffer.
|
// Initial capacity of the internal buffer.
|
||||||
const size_t kBufferInitialSize = 65536;
|
const size_t kBufferInitialSize = 65536;
|
||||||
@ -28,6 +28,11 @@ class Buffer {
|
|||||||
public:
|
public:
|
||||||
Buffer();
|
Buffer();
|
||||||
|
|
||||||
|
Buffer(const Buffer &) = delete;
|
||||||
|
Buffer(Buffer &&) = delete;
|
||||||
|
Buffer &operator=(const Buffer &) = delete;
|
||||||
|
Buffer &operator=(Buffer &&) = delete;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This class provides all functions from the buffer that are needed to allow
|
* This class provides all functions from the buffer that are needed to allow
|
||||||
* reading data from the buffer.
|
* reading data from the buffer.
|
||||||
@ -36,6 +41,11 @@ class Buffer {
|
|||||||
public:
|
public:
|
||||||
ReadEnd(Buffer &buffer);
|
ReadEnd(Buffer &buffer);
|
||||||
|
|
||||||
|
ReadEnd(const ReadEnd &) = delete;
|
||||||
|
ReadEnd(ReadEnd &&) = delete;
|
||||||
|
ReadEnd &operator=(const ReadEnd &) = delete;
|
||||||
|
ReadEnd &operator=(ReadEnd &&) = delete;
|
||||||
|
|
||||||
uint8_t *data();
|
uint8_t *data();
|
||||||
|
|
||||||
size_t size() const;
|
size_t size() const;
|
||||||
@ -58,6 +68,11 @@ class Buffer {
|
|||||||
public:
|
public:
|
||||||
WriteEnd(Buffer &buffer);
|
WriteEnd(Buffer &buffer);
|
||||||
|
|
||||||
|
WriteEnd(const WriteEnd &) = delete;
|
||||||
|
WriteEnd(WriteEnd &&) = delete;
|
||||||
|
WriteEnd &operator=(const WriteEnd &) = delete;
|
||||||
|
WriteEnd &operator=(WriteEnd &&) = delete;
|
||||||
|
|
||||||
io::network::StreamBuffer Allocate();
|
io::network::StreamBuffer Allocate();
|
||||||
|
|
||||||
void Written(size_t len);
|
void Written(size_t len);
|
||||||
|
232
src/communication/client.cpp
Normal file
232
src/communication/client.cpp
Normal file
@ -0,0 +1,232 @@
|
|||||||
|
#include <glog/logging.h>
|
||||||
|
|
||||||
|
#include "communication/client.hpp"
|
||||||
|
#include "communication/helpers.hpp"
|
||||||
|
|
||||||
|
namespace communication {
|
||||||
|
|
||||||
|
Client::Client(ClientContext *context) : context_(context) {}
|
||||||
|
|
||||||
|
Client::~Client() {
|
||||||
|
Close();
|
||||||
|
ReleaseSslObjects();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Client::Connect(const io::network::Endpoint &endpoint) {
|
||||||
|
// Try to establish a socket connection.
|
||||||
|
if (!socket_.Connect(endpoint)) return false;
|
||||||
|
|
||||||
|
// Enable TCP keep alive for all connections.
|
||||||
|
// Because we manually always set the `have_more` flag to the socket
|
||||||
|
// `Write` call we can disable the Nagle algorithm because we know that we
|
||||||
|
// are always sending optimal packets. Even if we don't send optimal
|
||||||
|
// packets, there will be no delay between packets and throughput won't
|
||||||
|
// suffer.
|
||||||
|
socket_.SetKeepAlive();
|
||||||
|
socket_.SetNoDelay();
|
||||||
|
|
||||||
|
if (context_->use_ssl()) {
|
||||||
|
// Release leftover SSL objects.
|
||||||
|
ReleaseSslObjects();
|
||||||
|
|
||||||
|
// Create a new SSL object that will be used for SSL communication.
|
||||||
|
ssl_ = SSL_new(context_->context());
|
||||||
|
if (ssl_ == nullptr) {
|
||||||
|
LOG(WARNING) << "Couldn't create client SSL object!";
|
||||||
|
socket_.Close();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new BIO (block I/O) SSL object so that OpenSSL can communicate
|
||||||
|
// using our socket. We specify `BIO_NOCLOSE` to indicate to OpenSSL that
|
||||||
|
// it doesn't need to close the socket when destructing all objects (we
|
||||||
|
// handle that in our socket destructor).
|
||||||
|
bio_ = BIO_new_socket(socket_.fd(), BIO_NOCLOSE);
|
||||||
|
if (bio_ == nullptr) {
|
||||||
|
LOG(WARNING) << "Couldn't create client BIO object!";
|
||||||
|
socket_.Close();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect the BIO object to the SSL object so that OpenSSL knows which
|
||||||
|
// stream it should use for communication. We use the same object for both
|
||||||
|
// the read and write end. This function cannot fail.
|
||||||
|
SSL_set_bio(ssl_, bio_, bio_);
|
||||||
|
|
||||||
|
// Clear all leftover errors.
|
||||||
|
ERR_clear_error();
|
||||||
|
|
||||||
|
// Perform the TLS handshake.
|
||||||
|
auto ret = SSL_connect(ssl_);
|
||||||
|
if (ret != 1) {
|
||||||
|
LOG(WARNING) << "Couldn't connect to SSL server: " << SslGetLastError();
|
||||||
|
socket_.Close();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Client::ErrorStatus() { return socket_.ErrorStatus(); }
|
||||||
|
|
||||||
|
void Client::Shutdown() { socket_.Shutdown(); }
|
||||||
|
|
||||||
|
void Client::Close() {
|
||||||
|
if (ssl_) {
|
||||||
|
// Perform an unidirectional SSL shutdown. That just means that we send
|
||||||
|
// the shutdown message and don't wait or care for the result.
|
||||||
|
SSL_shutdown(ssl_);
|
||||||
|
}
|
||||||
|
socket_.Close();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Client::Read(size_t len) {
|
||||||
|
size_t received = 0;
|
||||||
|
buffer_.write_end().Resize(buffer_.read_end().size() + len);
|
||||||
|
while (received < len) {
|
||||||
|
auto buff = buffer_.write_end().Allocate();
|
||||||
|
if (ssl_) {
|
||||||
|
// We clear errors here to prevent errors piling up in the internal
|
||||||
|
// OpenSSL error queue. To see when could that be an issue read this:
|
||||||
|
// https://www.arangodb.com/2014/07/started-hate-openssl/
|
||||||
|
ERR_clear_error();
|
||||||
|
|
||||||
|
// Read encrypted data from the socket using OpenSSL.
|
||||||
|
auto got = SSL_read(ssl_, buff.data, len - received);
|
||||||
|
|
||||||
|
// Handle errors that might have occurred.
|
||||||
|
if (got < 0) {
|
||||||
|
auto err = SSL_get_error(ssl_, got);
|
||||||
|
if (err == SSL_ERROR_WANT_READ) {
|
||||||
|
// OpenSSL want's to read more data from the socket. We wait for
|
||||||
|
// more data to be ready and retry the call.
|
||||||
|
socket_.WaitForReadyRead();
|
||||||
|
continue;
|
||||||
|
} else if (err == SSL_ERROR_WANT_WRITE) {
|
||||||
|
// The OpenSSL library probably wants to perform some kind of
|
||||||
|
// handshake so we wait for the socket to become ready for a write
|
||||||
|
// and call the read again.
|
||||||
|
socket_.WaitForReadyWrite();
|
||||||
|
continue;
|
||||||
|
} else {
|
||||||
|
// This is a fatal error.
|
||||||
|
LOG(WARNING) << "Received an unexpected SSL error: " << err;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} else if (got == 0) {
|
||||||
|
// The server closed the connection.
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Notify the buffer that it has new data.
|
||||||
|
buffer_.write_end().Written(got);
|
||||||
|
received += got;
|
||||||
|
} else {
|
||||||
|
// Read raw data from the socket.
|
||||||
|
auto got = socket_.Read(buff.data, len - received);
|
||||||
|
|
||||||
|
if (got <= 0) {
|
||||||
|
// If `read` returns 0 the server has closed the connection. If `read`
|
||||||
|
// returns -1 all of the errors that could be found in `errno` are
|
||||||
|
// fatal errors (because we are using a blocking socket) so return a
|
||||||
|
// read failure.
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Notify the buffer that it has new data.
|
||||||
|
buffer_.write_end().Written(got);
|
||||||
|
received += got;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint8_t *Client::GetData() { return buffer_.read_end().data(); }
|
||||||
|
|
||||||
|
size_t Client::GetDataSize() { return buffer_.read_end().size(); }
|
||||||
|
|
||||||
|
void Client::ShiftData(size_t len) { buffer_.read_end().Shift(len); }
|
||||||
|
|
||||||
|
void Client::ClearData() { buffer_.read_end().Clear(); }
|
||||||
|
|
||||||
|
bool Client::Write(const uint8_t *data, size_t len, bool have_more) {
|
||||||
|
if (ssl_) {
|
||||||
|
// `SSL_write` has the interface of a normal `write` call. Because of that
|
||||||
|
// we need to ensure that all data is written to the socket manually.
|
||||||
|
while (len > 0) {
|
||||||
|
// We clear errors here to prevent errors piling up in the internal
|
||||||
|
// OpenSSL error queue. To see when could that be an issue read this:
|
||||||
|
// https://www.arangodb.com/2014/07/started-hate-openssl/
|
||||||
|
ERR_clear_error();
|
||||||
|
|
||||||
|
// Write data to the socket using OpenSSL.
|
||||||
|
auto written = SSL_write(ssl_, data, len);
|
||||||
|
if (written < 0) {
|
||||||
|
auto err = SSL_get_error(ssl_, written);
|
||||||
|
if (err == SSL_ERROR_WANT_READ) {
|
||||||
|
// OpenSSL wants to perform some kind of handshake, we need to
|
||||||
|
// ensure that there is data available for the next call to
|
||||||
|
// `SSL_write`.
|
||||||
|
socket_.WaitForReadyRead();
|
||||||
|
} else if (err == SSL_ERROR_WANT_WRITE) {
|
||||||
|
// The socket probably returned WOULDBLOCK and we need to wait for
|
||||||
|
// the output buffers to clear and reattempt the send.
|
||||||
|
socket_.WaitForReadyWrite();
|
||||||
|
} else {
|
||||||
|
// This is a fatal error.
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} else if (written == 0) {
|
||||||
|
// The client closed the connection.
|
||||||
|
return false;
|
||||||
|
} else {
|
||||||
|
len -= written;
|
||||||
|
data += written;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
return socket_.Write(data, len, have_more);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Client::Write(const std::string &str, bool have_more) {
|
||||||
|
return Write(reinterpret_cast<const uint8_t *>(str.data()), str.size(),
|
||||||
|
have_more);
|
||||||
|
}
|
||||||
|
|
||||||
|
const io::network::Endpoint &Client::endpoint() { return socket_.endpoint(); }
|
||||||
|
|
||||||
|
void Client::ReleaseSslObjects() {
|
||||||
|
// If we are using SSL we need to free the allocated objects. Here we only
|
||||||
|
// free the SSL object because the `SSL_free` function also automatically
|
||||||
|
// frees the BIO object.
|
||||||
|
if (ssl_) {
|
||||||
|
SSL_free(ssl_);
|
||||||
|
ssl_ = nullptr;
|
||||||
|
bio_ = nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ClientInputStream::ClientInputStream(Client &client) : client_(client) {}
|
||||||
|
|
||||||
|
uint8_t *ClientInputStream::data() { return client_.GetData(); }
|
||||||
|
|
||||||
|
size_t ClientInputStream::size() const { return client_.GetDataSize(); }
|
||||||
|
|
||||||
|
void ClientInputStream::Shift(size_t len) { client_.ShiftData(len); }
|
||||||
|
|
||||||
|
void ClientInputStream::Clear() { client_.ClearData(); }
|
||||||
|
|
||||||
|
ClientOutputStream::ClientOutputStream(Client &client) : client_(client) {}
|
||||||
|
|
||||||
|
bool ClientOutputStream::Write(const uint8_t *data, size_t len,
|
||||||
|
bool have_more) {
|
||||||
|
return client_.Write(data, len, have_more);
|
||||||
|
}
|
||||||
|
bool ClientOutputStream::Write(const std::string &str, bool have_more) {
|
||||||
|
return client_.Write(str, have_more);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace communication
|
@ -1,6 +1,12 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <openssl/bio.h>
|
||||||
|
#include <openssl/err.h>
|
||||||
|
#include <openssl/ssl.h>
|
||||||
|
|
||||||
#include "communication/buffer.hpp"
|
#include "communication/buffer.hpp"
|
||||||
|
#include "communication/context.hpp"
|
||||||
|
#include "communication/init.hpp"
|
||||||
#include "io/network/endpoint.hpp"
|
#include "io/network/endpoint.hpp"
|
||||||
#include "io/network/socket.hpp"
|
#include "io/network/socket.hpp"
|
||||||
|
|
||||||
@ -10,106 +16,115 @@ namespace communication {
|
|||||||
* This class implements a generic network Client.
|
* This class implements a generic network Client.
|
||||||
* It uses blocking sockets and provides an API that can be used to receive/send
|
* It uses blocking sockets and provides an API that can be used to receive/send
|
||||||
* data over the network connection.
|
* data over the network connection.
|
||||||
|
*
|
||||||
|
* NOTE: If you use this client you **must** call `communication::Init()` from
|
||||||
|
* the `main` function before using the client!
|
||||||
*/
|
*/
|
||||||
class Client {
|
class Client final {
|
||||||
public:
|
public:
|
||||||
|
explicit Client(ClientContext *context);
|
||||||
|
|
||||||
|
~Client();
|
||||||
|
|
||||||
|
Client(const Client &) = delete;
|
||||||
|
Client(Client &&) = delete;
|
||||||
|
Client &operator=(const Client &) = delete;
|
||||||
|
Client &operator=(Client &&) = delete;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This function connects to a remote server and returns whether the connect
|
* This function connects to a remote server and returns whether the connect
|
||||||
* succeeded.
|
* succeeded.
|
||||||
*/
|
*/
|
||||||
bool Connect(const io::network::Endpoint &endpoint) {
|
bool Connect(const io::network::Endpoint &endpoint);
|
||||||
if (!socket_.Connect(endpoint)) return false;
|
|
||||||
socket_.SetKeepAlive();
|
|
||||||
socket_.SetNoDelay();
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This function returns `true` if the socket is in an error state.
|
* This function returns `true` if the socket is in an error state.
|
||||||
*/
|
*/
|
||||||
bool ErrorStatus() { return socket_.ErrorStatus(); }
|
bool ErrorStatus();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This function shuts down the socket.
|
* This function shuts down the socket.
|
||||||
*/
|
*/
|
||||||
void Shutdown() { socket_.Shutdown(); }
|
void Shutdown();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This function closes the socket.
|
* This function closes the socket.
|
||||||
*/
|
*/
|
||||||
void Close() { socket_.Close(); }
|
void Close();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This function is used to receive `len` bytes from the socket and stores it
|
* This function is used to receive `len` bytes from the socket and stores it
|
||||||
* in an internal buffer. It returns `true` if the read succeeded and `false`
|
* in an internal buffer. It returns `true` if the read succeeded and `false`
|
||||||
* if it didn't.
|
* if it didn't.
|
||||||
*/
|
*/
|
||||||
bool Read(size_t len) {
|
bool Read(size_t len);
|
||||||
size_t received = 0;
|
|
||||||
buffer_.write_end().Resize(buffer_.read_end().size() + len);
|
|
||||||
while (received < len) {
|
|
||||||
auto buff = buffer_.write_end().Allocate();
|
|
||||||
int got = socket_.Read(buff.data, len - received);
|
|
||||||
if (got <= 0) return false;
|
|
||||||
buffer_.write_end().Written(got);
|
|
||||||
received += got;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This function returns a pointer to the read data that is currently stored
|
* This function returns a pointer to the read data that is currently stored
|
||||||
* in the client.
|
* in the client.
|
||||||
*/
|
*/
|
||||||
uint8_t *GetData() { return buffer_.read_end().data(); }
|
uint8_t *GetData();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This function returns the size of the read data that is currently stored in
|
* This function returns the size of the read data that is currently stored in
|
||||||
* the client.
|
* the client.
|
||||||
*/
|
*/
|
||||||
size_t GetDataSize() { return buffer_.read_end().size(); }
|
size_t GetDataSize();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This function removes first `len` bytes from the data buffer.
|
* This function removes first `len` bytes from the data buffer.
|
||||||
*/
|
*/
|
||||||
void ShiftData(size_t len) { buffer_.read_end().Shift(len); }
|
void ShiftData(size_t len);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This function clears the data buffer.
|
* This function clears the data buffer.
|
||||||
*/
|
*/
|
||||||
void ClearData() { buffer_.read_end().Clear(); }
|
void ClearData();
|
||||||
|
|
||||||
// Write end
|
/**
|
||||||
bool Write(const uint8_t *data, size_t len, bool have_more = false) {
|
* This function writes data to the socket.
|
||||||
return socket_.Write(data, len, have_more);
|
* TODO (mferencevic): the `have_more` flag currently isn't supported when
|
||||||
}
|
* using OpenSSL
|
||||||
bool Write(const std::string &str, bool have_more = false) {
|
*/
|
||||||
return Write(reinterpret_cast<const uint8_t *>(str.data()), str.size(),
|
bool Write(const uint8_t *data, size_t len, bool have_more = false);
|
||||||
have_more);
|
|
||||||
}
|
|
||||||
|
|
||||||
const io::network::Endpoint &endpoint() { return socket_.endpoint(); }
|
/**
|
||||||
|
* This function writes data to the socket.
|
||||||
|
*/
|
||||||
|
bool Write(const std::string &str, bool have_more = false);
|
||||||
|
|
||||||
|
const io::network::Endpoint &endpoint();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
io::network::Socket socket_;
|
void ReleaseSslObjects();
|
||||||
|
|
||||||
|
io::network::Socket socket_;
|
||||||
Buffer buffer_;
|
Buffer buffer_;
|
||||||
|
|
||||||
|
ClientContext *context_;
|
||||||
|
SSL *ssl_{nullptr};
|
||||||
|
BIO *bio_{nullptr};
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This class provides a stream-like input side object to the client.
|
* This class provides a stream-like input side object to the client.
|
||||||
*/
|
*/
|
||||||
class ClientInputStream {
|
class ClientInputStream final {
|
||||||
public:
|
public:
|
||||||
ClientInputStream(Client &client) : client_(client) {}
|
ClientInputStream(Client &client);
|
||||||
|
|
||||||
uint8_t *data() { return client_.GetData(); }
|
ClientInputStream(const ClientInputStream &) = delete;
|
||||||
|
ClientInputStream(ClientInputStream &&) = delete;
|
||||||
|
ClientInputStream &operator=(const ClientInputStream &) = delete;
|
||||||
|
ClientInputStream &operator=(ClientInputStream &&) = delete;
|
||||||
|
|
||||||
size_t size() const { return client_.GetDataSize(); }
|
uint8_t *data();
|
||||||
|
|
||||||
void Shift(size_t len) { client_.ShiftData(len); }
|
size_t size() const;
|
||||||
|
|
||||||
void Clear() { client_.ClearData(); }
|
void Shift(size_t len);
|
||||||
|
|
||||||
|
void Clear();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Client &client_;
|
Client &client_;
|
||||||
@ -118,16 +133,18 @@ class ClientInputStream {
|
|||||||
/**
|
/**
|
||||||
* This class provides a stream-like output side object to the client.
|
* This class provides a stream-like output side object to the client.
|
||||||
*/
|
*/
|
||||||
class ClientOutputStream {
|
class ClientOutputStream final {
|
||||||
public:
|
public:
|
||||||
ClientOutputStream(Client &client) : client_(client) {}
|
ClientOutputStream(Client &client);
|
||||||
|
|
||||||
bool Write(const uint8_t *data, size_t len, bool have_more = false) {
|
ClientOutputStream(const ClientOutputStream &) = delete;
|
||||||
return client_.Write(data, len, have_more);
|
ClientOutputStream(ClientOutputStream &&) = delete;
|
||||||
}
|
ClientOutputStream &operator=(const ClientOutputStream &) = delete;
|
||||||
bool Write(const std::string &str, bool have_more = false) {
|
ClientOutputStream &operator=(ClientOutputStream &&) = delete;
|
||||||
return client_.Write(str, have_more);
|
|
||||||
}
|
bool Write(const uint8_t *data, size_t len, bool have_more = false);
|
||||||
|
|
||||||
|
bool Write(const std::string &str, bool have_more = false);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Client &client_;
|
Client &client_;
|
||||||
|
80
src/communication/context.cpp
Normal file
80
src/communication/context.cpp
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
#include <glog/logging.h>
|
||||||
|
|
||||||
|
#include "communication/context.hpp"
|
||||||
|
|
||||||
|
namespace communication {
|
||||||
|
|
||||||
|
ClientContext::ClientContext(bool use_ssl) : use_ssl_(use_ssl), ctx_(nullptr) {
|
||||||
|
if (use_ssl_) {
|
||||||
|
ctx_ = SSL_CTX_new(TLS_client_method());
|
||||||
|
CHECK(ctx_ != nullptr) << "Couldn't create client SSL_CTX object!";
|
||||||
|
|
||||||
|
// Disable legacy SSL support. Other options can be seen here:
|
||||||
|
// https://www.openssl.org/docs/man1.0.2/ssl/SSL_CTX_set_options.html
|
||||||
|
SSL_CTX_set_options(ctx_, SSL_OP_NO_SSLv3);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ClientContext::ClientContext(const std::string &key_file,
|
||||||
|
const std::string &cert_file)
|
||||||
|
: ClientContext(true) {
|
||||||
|
if (key_file != "" && cert_file != "") {
|
||||||
|
CHECK(SSL_CTX_use_certificate_file(ctx_, cert_file.c_str(),
|
||||||
|
SSL_FILETYPE_PEM) == 1)
|
||||||
|
<< "Couldn't load client certificate from file: " << cert_file;
|
||||||
|
CHECK(SSL_CTX_use_PrivateKey_file(ctx_, key_file.c_str(),
|
||||||
|
SSL_FILETYPE_PEM) == 1)
|
||||||
|
<< "Couldn't load client private key from file: " << key_file;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
SSL_CTX *ClientContext::context() { return ctx_; }
|
||||||
|
|
||||||
|
bool ClientContext::use_ssl() { return use_ssl_; }
|
||||||
|
|
||||||
|
ServerContext::ServerContext() : use_ssl_(false), ctx_(nullptr) {}
|
||||||
|
|
||||||
|
ServerContext::ServerContext(const std::string &key_file,
|
||||||
|
const std::string &cert_file,
|
||||||
|
const std::string &ca_file, bool verify_peer)
|
||||||
|
: use_ssl_(true), ctx_(SSL_CTX_new(TLS_server_method())) {
|
||||||
|
// TODO (mferencevic): add support for encrypted private keys
|
||||||
|
// TODO (mferencevic): add certificate revocation list (CRL)
|
||||||
|
CHECK(SSL_CTX_use_certificate_file(ctx_, cert_file.c_str(),
|
||||||
|
SSL_FILETYPE_PEM) == 1)
|
||||||
|
<< "Couldn't load server certificate from file: " << cert_file;
|
||||||
|
CHECK(SSL_CTX_use_PrivateKey_file(ctx_, key_file.c_str(), SSL_FILETYPE_PEM) ==
|
||||||
|
1)
|
||||||
|
<< "Couldn't load server private key from file: " << key_file;
|
||||||
|
|
||||||
|
// Disable legacy SSL support. Other options can be seen here:
|
||||||
|
// https://www.openssl.org/docs/man1.0.2/ssl/SSL_CTX_set_options.html
|
||||||
|
SSL_CTX_set_options(ctx_, SSL_OP_NO_SSLv3);
|
||||||
|
|
||||||
|
if (ca_file != "") {
|
||||||
|
// Load the certificate authority file.
|
||||||
|
CHECK(SSL_CTX_load_verify_locations(ctx_, ca_file.c_str(), nullptr) == 1)
|
||||||
|
<< "Couldn't load certificate authority from file: " << ca_file;
|
||||||
|
|
||||||
|
if (verify_peer) {
|
||||||
|
// Add the CA to list of accepted CAs that is sent to the client.
|
||||||
|
STACK_OF(X509_NAME) *ca_names = SSL_load_client_CA_file(ca_file.c_str());
|
||||||
|
CHECK(ca_names != nullptr)
|
||||||
|
<< "Couldn't load certificate authority from file: " << ca_file;
|
||||||
|
// `ca_names` doesn' need to be free'd because we pass it to
|
||||||
|
// `SSL_CTX_set_client_CA_list`:
|
||||||
|
// https://mta.openssl.org/pipermail/openssl-users/2015-May/001363.html
|
||||||
|
SSL_CTX_set_client_CA_list(ctx_, ca_names);
|
||||||
|
|
||||||
|
// Enable verification of the client certificate.
|
||||||
|
SSL_CTX_set_verify(
|
||||||
|
ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
SSL_CTX *ServerContext::context() { return ctx_; }
|
||||||
|
|
||||||
|
bool ServerContext::use_ssl() { return use_ssl_; }
|
||||||
|
|
||||||
|
} // namespace communication
|
63
src/communication/context.hpp
Normal file
63
src/communication/context.hpp
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include <openssl/ssl.h>
|
||||||
|
|
||||||
|
namespace communication {
|
||||||
|
|
||||||
|
class ClientContext final {
|
||||||
|
public:
|
||||||
|
/**
|
||||||
|
* This constructor constructs a ClientContext that can either not use SSL
|
||||||
|
* (`use_ssl` is `false` by default), or it constructs a ClientContext that
|
||||||
|
* doesn't use a client certificate when `use_ssl` is set to `true`.
|
||||||
|
*/
|
||||||
|
explicit ClientContext(bool use_ssl = false);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This constructor constructs a ClientContext that uses SSL and uses the
|
||||||
|
* specific client private key and certificate combination. If the parameters
|
||||||
|
* `key_file` and `cert_file` are equal to "" then the constructor falls back
|
||||||
|
* to the above constructor that uses SSL without certificates.
|
||||||
|
*/
|
||||||
|
ClientContext(const std::string &key_file, const std::string &cert_file);
|
||||||
|
|
||||||
|
SSL_CTX *context();
|
||||||
|
|
||||||
|
bool use_ssl();
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool use_ssl_;
|
||||||
|
SSL_CTX *ctx_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class ServerContext final {
|
||||||
|
public:
|
||||||
|
/**
|
||||||
|
* This constructor constructs a ServerContext that doesn't use SSL.
|
||||||
|
*/
|
||||||
|
ServerContext();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This constructor constructs a ServerContext that uses SSL. The parameters
|
||||||
|
* `key_file` and `cert_file` can't be "" because when setting up a server it
|
||||||
|
* is mandatory to supply a private key and certificate. The parameter
|
||||||
|
* `ca_file` can be "" because SSL doesn't necessarily need to check that the
|
||||||
|
* client has a valid certificate. If you specify `verify_peer` to be `true`
|
||||||
|
* to check that the client certificate is valid, then you need to supply a
|
||||||
|
* valid `ca_file` as well.
|
||||||
|
*/
|
||||||
|
ServerContext(const std::string &key_file, const std::string &cert_file,
|
||||||
|
const std::string &ca_file = "", bool verify_peer = false);
|
||||||
|
|
||||||
|
SSL_CTX *context();
|
||||||
|
|
||||||
|
bool use_ssl();
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool use_ssl_;
|
||||||
|
SSL_CTX *ctx_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace communication
|
13
src/communication/helpers.cpp
Normal file
13
src/communication/helpers.cpp
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
#include <openssl/err.h>
|
||||||
|
|
||||||
|
#include "communication/helpers.hpp"
|
||||||
|
|
||||||
|
namespace communication {
|
||||||
|
|
||||||
|
const std::string SslGetLastError() {
|
||||||
|
char buff[2048];
|
||||||
|
auto err = ERR_get_error();
|
||||||
|
ERR_error_string_n(err, buff, sizeof(buff));
|
||||||
|
return std::string(buff);
|
||||||
|
}
|
||||||
|
} // namespace communication
|
12
src/communication/helpers.hpp
Normal file
12
src/communication/helpers.hpp
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
namespace communication {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This function reads and returns a string describing the last OpenSSL error.
|
||||||
|
*/
|
||||||
|
const std::string SslGetLastError();
|
||||||
|
|
||||||
|
} // namespace communication
|
21
src/communication/init.cpp
Normal file
21
src/communication/init.cpp
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
#include <glog/logging.h>
|
||||||
|
|
||||||
|
#include <openssl/bio.h>
|
||||||
|
#include <openssl/err.h>
|
||||||
|
#include <openssl/ssl.h>
|
||||||
|
|
||||||
|
#include "utils/signals.hpp"
|
||||||
|
|
||||||
|
namespace communication {
|
||||||
|
|
||||||
|
void Init() {
|
||||||
|
// Initialize the OpenSSL library.
|
||||||
|
SSL_library_init();
|
||||||
|
OpenSSL_add_ssl_algorithms();
|
||||||
|
SSL_load_error_strings();
|
||||||
|
ERR_load_crypto_strings();
|
||||||
|
|
||||||
|
// Ignore SIGPIPE.
|
||||||
|
CHECK(utils::SignalIgnore(utils::Signal::Pipe)) << "Couldn't ignore SIGPIPE!";
|
||||||
|
}
|
||||||
|
} // namespace communication
|
16
src/communication/init.hpp
Normal file
16
src/communication/init.hpp
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
namespace communication {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Call this function in each `main` file that uses the Communication stack. It
|
||||||
|
* is used to initialize all libraries (primarily OpenSSL) and to fix some
|
||||||
|
* issues also related to OpenSSL (handling of SIGPIPE).
|
||||||
|
*
|
||||||
|
* Description of OpenSSL init can be seen here:
|
||||||
|
* https://wiki.openssl.org/index.php/Library_Initialization
|
||||||
|
*
|
||||||
|
* NOTE: This function must be called **exactly** once.
|
||||||
|
*/
|
||||||
|
void Init();
|
||||||
|
} // namespace communication
|
@ -13,6 +13,7 @@
|
|||||||
#include "communication/session.hpp"
|
#include "communication/session.hpp"
|
||||||
#include "io/network/epoll.hpp"
|
#include "io/network/epoll.hpp"
|
||||||
#include "io/network/socket.hpp"
|
#include "io/network/socket.hpp"
|
||||||
|
#include "utils/signals.hpp"
|
||||||
#include "utils/thread.hpp"
|
#include "utils/thread.hpp"
|
||||||
#include "utils/thread/sync.hpp"
|
#include "utils/thread/sync.hpp"
|
||||||
|
|
||||||
@ -28,7 +29,7 @@ namespace communication {
|
|||||||
* expired.
|
* expired.
|
||||||
*/
|
*/
|
||||||
template <class TSession, class TSessionData>
|
template <class TSession, class TSessionData>
|
||||||
class Listener {
|
class Listener final {
|
||||||
private:
|
private:
|
||||||
// The maximum number of events handled per execution thread is 1. This is
|
// The maximum number of events handled per execution thread is 1. This is
|
||||||
// because each event represents the start of a network request and it doesn't
|
// because each event represents the start of a network request and it doesn't
|
||||||
@ -39,14 +40,27 @@ class Listener {
|
|||||||
using SessionHandler = Session<TSession, TSessionData>;
|
using SessionHandler = Session<TSession, TSessionData>;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
Listener(TSessionData &data, int inactivity_timeout_sec,
|
Listener(TSessionData &data, ServerContext *context,
|
||||||
const std::string &service_name)
|
int inactivity_timeout_sec, const std::string &service_name,
|
||||||
|
size_t workers_count)
|
||||||
: data_(data),
|
: data_(data),
|
||||||
alive_(true),
|
alive_(true),
|
||||||
|
context_(context),
|
||||||
inactivity_timeout_sec_(inactivity_timeout_sec),
|
inactivity_timeout_sec_(inactivity_timeout_sec),
|
||||||
service_name_(service_name) {
|
service_name_(service_name) {
|
||||||
|
std::cout << "Starting " << workers_count << " " << service_name_
|
||||||
|
<< " workers" << std::endl;
|
||||||
|
for (size_t i = 0; i < workers_count; ++i) {
|
||||||
|
worker_threads_.emplace_back([this, service_name, i]() {
|
||||||
|
utils::ThreadSetName(fmt::format("{} worker {}", service_name, i + 1));
|
||||||
|
while (alive_) {
|
||||||
|
WaitAndProcessEvents();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
if (inactivity_timeout_sec_ > 0) {
|
if (inactivity_timeout_sec_ > 0) {
|
||||||
thread_ = std::thread([this, service_name]() {
|
timeout_thread_ = std::thread([this, service_name]() {
|
||||||
utils::ThreadSetName(fmt::format("{} timeout", service_name));
|
utils::ThreadSetName(fmt::format("{} timeout", service_name));
|
||||||
while (alive_) {
|
while (alive_) {
|
||||||
{
|
{
|
||||||
@ -72,7 +86,10 @@ class Listener {
|
|||||||
|
|
||||||
~Listener() {
|
~Listener() {
|
||||||
alive_.store(false);
|
alive_.store(false);
|
||||||
if (thread_.joinable()) thread_.join();
|
if (timeout_thread_.joinable()) timeout_thread_.join();
|
||||||
|
for (auto &worker_thread : worker_threads_) {
|
||||||
|
worker_thread.join();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Listener(const Listener &) = delete;
|
Listener(const Listener &) = delete;
|
||||||
@ -88,20 +105,12 @@ class Listener {
|
|||||||
void AddConnection(io::network::Socket &&connection) {
|
void AddConnection(io::network::Socket &&connection) {
|
||||||
std::unique_lock<utils::SpinLock> guard(lock_);
|
std::unique_lock<utils::SpinLock> guard(lock_);
|
||||||
|
|
||||||
// Set connection options.
|
|
||||||
// The socket is left to be a blocking socket, but when `Read` is called
|
|
||||||
// then a flag is manually set to enable non-blocking read that is used in
|
|
||||||
// conjunction with `EPOLLET`. That means that the socket is used in a
|
|
||||||
// non-blocking fashion for reads and a blocking fashion for writes.
|
|
||||||
connection.SetKeepAlive();
|
|
||||||
connection.SetNoDelay();
|
|
||||||
|
|
||||||
// Remember fd before moving connection into Session.
|
// Remember fd before moving connection into Session.
|
||||||
int fd = connection.fd();
|
int fd = connection.fd();
|
||||||
|
|
||||||
// Create a new Session for the connection.
|
// Create a new Session for the connection.
|
||||||
sessions_.push_back(std::make_unique<SessionHandler>(
|
sessions_.push_back(std::make_unique<SessionHandler>(
|
||||||
std::move(connection), data_, inactivity_timeout_sec_));
|
std::move(connection), data_, context_, inactivity_timeout_sec_));
|
||||||
|
|
||||||
// Register the connection in Epoll.
|
// Register the connection in Epoll.
|
||||||
// We want to listen to an incoming event which is edge triggered and
|
// We want to listen to an incoming event which is edge triggered and
|
||||||
@ -113,6 +122,7 @@ class Listener {
|
|||||||
sessions_.back().get());
|
sessions_.back().get());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
/**
|
/**
|
||||||
* This function polls the event queue and processes incoming data.
|
* This function polls the event queue and processes incoming data.
|
||||||
* It is thread safe and is intended to be called from multiple threads and
|
* It is thread safe and is intended to be called from multiple threads and
|
||||||
@ -168,7 +178,6 @@ class Listener {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
|
||||||
bool ExecuteSession(SessionHandler &session) {
|
bool ExecuteSession(SessionHandler &session) {
|
||||||
try {
|
try {
|
||||||
if (session.Execute()) {
|
if (session.Execute()) {
|
||||||
@ -223,8 +232,11 @@ class Listener {
|
|||||||
utils::SpinLock lock_;
|
utils::SpinLock lock_;
|
||||||
std::vector<std::unique_ptr<SessionHandler>> sessions_;
|
std::vector<std::unique_ptr<SessionHandler>> sessions_;
|
||||||
|
|
||||||
std::thread thread_;
|
std::thread timeout_thread_;
|
||||||
|
std::vector<std::thread> worker_threads_;
|
||||||
std::atomic<bool> alive_;
|
std::atomic<bool> alive_;
|
||||||
|
|
||||||
|
ServerContext *context_;
|
||||||
const int inactivity_timeout_sec_;
|
const int inactivity_timeout_sec_;
|
||||||
const std::string service_name_;
|
const std::string service_name_;
|
||||||
};
|
};
|
||||||
|
@ -30,7 +30,7 @@ std::experimental::optional<::capnp::FlatArrayMessageReader> Client::Send(
|
|||||||
|
|
||||||
// Connect to the remote server.
|
// Connect to the remote server.
|
||||||
if (!client_) {
|
if (!client_) {
|
||||||
client_.emplace();
|
client_.emplace(&context_);
|
||||||
if (!client_->Connect(endpoint_)) {
|
if (!client_->Connect(endpoint_)) {
|
||||||
LOG(ERROR) << "Couldn't connect to remote address " << endpoint_;
|
LOG(ERROR) << "Couldn't connect to remote address " << endpoint_;
|
||||||
client_ = std::experimental::nullopt;
|
client_ = std::experimental::nullopt;
|
||||||
|
@ -86,6 +86,8 @@ class Client {
|
|||||||
::capnp::MessageBuilder *message);
|
::capnp::MessageBuilder *message);
|
||||||
|
|
||||||
io::network::Endpoint endpoint_;
|
io::network::Endpoint endpoint_;
|
||||||
|
// TODO (mferencevic): currently the RPC client is hardcoded not to use SSL
|
||||||
|
communication::ClientContext context_;
|
||||||
std::experimental::optional<communication::Client> client_;
|
std::experimental::optional<communication::Client> client_;
|
||||||
|
|
||||||
std::mutex mutex_;
|
std::mutex mutex_;
|
||||||
|
@ -4,7 +4,7 @@ namespace communication::rpc {
|
|||||||
|
|
||||||
Server::Server(const io::network::Endpoint &endpoint,
|
Server::Server(const io::network::Endpoint &endpoint,
|
||||||
size_t workers_count)
|
size_t workers_count)
|
||||||
: server_(endpoint, *this, -1, "RPC", workers_count) {}
|
: server_(endpoint, *this, &context_, -1, "RPC", workers_count) {}
|
||||||
|
|
||||||
void Server::StopProcessingCalls() {
|
void Server::StopProcessingCalls() {
|
||||||
server_.Shutdown();
|
server_.Shutdown();
|
||||||
|
@ -78,6 +78,8 @@ class Server {
|
|||||||
ConcurrentMap<uint64_t, RpcCallback> callbacks_;
|
ConcurrentMap<uint64_t, RpcCallback> callbacks_;
|
||||||
|
|
||||||
std::mutex mutex_;
|
std::mutex mutex_;
|
||||||
|
// TODO (mferencevic): currently the RPC server is hardcoded not to use SSL
|
||||||
|
communication::ServerContext context_;
|
||||||
communication::Server<Session, Server> server_;
|
communication::Server<Session, Server> server_;
|
||||||
}; // namespace communication::rpc
|
}; // namespace communication::rpc
|
||||||
|
|
||||||
|
@ -10,6 +10,7 @@
|
|||||||
#include <fmt/format.h>
|
#include <fmt/format.h>
|
||||||
#include <glog/logging.h>
|
#include <glog/logging.h>
|
||||||
|
|
||||||
|
#include "communication/init.hpp"
|
||||||
#include "communication/listener.hpp"
|
#include "communication/listener.hpp"
|
||||||
#include "io/network/socket.hpp"
|
#include "io/network/socket.hpp"
|
||||||
#include "utils/thread.hpp"
|
#include "utils/thread.hpp"
|
||||||
@ -27,6 +28,9 @@ namespace communication {
|
|||||||
* Current Server achitecture:
|
* Current Server achitecture:
|
||||||
* incoming connection -> server -> listener -> session
|
* incoming connection -> server -> listener -> session
|
||||||
*
|
*
|
||||||
|
* NOTE: If you use this server you **must** call `communication::Init()` from
|
||||||
|
* the `main` function before using the server!
|
||||||
|
*
|
||||||
* @tparam TSession the server can handle different Sessions, each session
|
* @tparam TSession the server can handle different Sessions, each session
|
||||||
* represents a different protocol so the same network infrastructure
|
* represents a different protocol so the same network infrastructure
|
||||||
* can be used for handling different protocols
|
* can be used for handling different protocols
|
||||||
@ -34,7 +38,7 @@ namespace communication {
|
|||||||
* session
|
* session
|
||||||
*/
|
*/
|
||||||
template <typename TSession, typename TSessionData>
|
template <typename TSession, typename TSessionData>
|
||||||
class Server {
|
class Server final {
|
||||||
public:
|
public:
|
||||||
using Socket = io::network::Socket;
|
using Socket = io::network::Socket;
|
||||||
|
|
||||||
@ -43,9 +47,11 @@ class Server {
|
|||||||
* invokes workers_count workers
|
* invokes workers_count workers
|
||||||
*/
|
*/
|
||||||
Server(const io::network::Endpoint &endpoint, TSessionData &session_data,
|
Server(const io::network::Endpoint &endpoint, TSessionData &session_data,
|
||||||
int inactivity_timeout_sec, const std::string &service_name,
|
ServerContext *context, int inactivity_timeout_sec,
|
||||||
|
const std::string &service_name,
|
||||||
size_t workers_count = std::thread::hardware_concurrency())
|
size_t workers_count = std::thread::hardware_concurrency())
|
||||||
: listener_(session_data, inactivity_timeout_sec, service_name),
|
: listener_(session_data, context, inactivity_timeout_sec, service_name,
|
||||||
|
workers_count),
|
||||||
service_name_(service_name) {
|
service_name_(service_name) {
|
||||||
// Without server we can't continue with application so we can just
|
// Without server we can't continue with application so we can just
|
||||||
// terminate here.
|
// terminate here.
|
||||||
@ -58,18 +64,7 @@ class Server {
|
|||||||
}
|
}
|
||||||
|
|
||||||
thread_ = std::thread([this, workers_count, service_name]() {
|
thread_ = std::thread([this, workers_count, service_name]() {
|
||||||
std::cout << "Starting " << workers_count << " " << service_name
|
|
||||||
<< " workers" << std::endl;
|
|
||||||
utils::ThreadSetName(fmt::format("{} server", service_name));
|
utils::ThreadSetName(fmt::format("{} server", service_name));
|
||||||
for (size_t i = 0; i < workers_count; ++i) {
|
|
||||||
worker_threads_.emplace_back([this, service_name, i]() {
|
|
||||||
utils::ThreadSetName(
|
|
||||||
fmt::format("{} worker {}", service_name, i + 1));
|
|
||||||
while (alive_) {
|
|
||||||
listener_.WaitAndProcessEvents();
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
std::cout << service_name << " server is fully armed and operational"
|
std::cout << service_name << " server is fully armed and operational"
|
||||||
<< std::endl;
|
<< std::endl;
|
||||||
@ -81,9 +76,6 @@ class Server {
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::cout << service_name << " shutting down..." << std::endl;
|
std::cout << service_name << " shutting down..." << std::endl;
|
||||||
for (auto &worker_thread : worker_threads_) {
|
|
||||||
worker_thread.join();
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -128,7 +120,6 @@ class Server {
|
|||||||
|
|
||||||
std::atomic<bool> alive_{true};
|
std::atomic<bool> alive_{true};
|
||||||
std::thread thread_;
|
std::thread thread_;
|
||||||
std::vector<std::thread> worker_threads_;
|
|
||||||
|
|
||||||
Socket socket_;
|
Socket socket_;
|
||||||
Listener<TSession, TSessionData> listener_;
|
Listener<TSession, TSessionData> listener_;
|
||||||
|
@ -9,7 +9,13 @@
|
|||||||
|
|
||||||
#include <glog/logging.h>
|
#include <glog/logging.h>
|
||||||
|
|
||||||
|
#include <openssl/bio.h>
|
||||||
|
#include <openssl/err.h>
|
||||||
|
#include <openssl/ssl.h>
|
||||||
|
|
||||||
#include "communication/buffer.hpp"
|
#include "communication/buffer.hpp"
|
||||||
|
#include "communication/context.hpp"
|
||||||
|
#include "communication/helpers.hpp"
|
||||||
#include "io/network/socket.hpp"
|
#include "io/network/socket.hpp"
|
||||||
#include "io/network/stream_buffer.hpp"
|
#include "io/network/stream_buffer.hpp"
|
||||||
#include "utils/exceptions.hpp"
|
#include "utils/exceptions.hpp"
|
||||||
@ -35,12 +41,19 @@ using InputStream = Buffer::ReadEnd;
|
|||||||
* This is used to provide output from user sessions. All sessions used with the
|
* This is used to provide output from user sessions. All sessions used with the
|
||||||
* network stack should use this class for their output stream.
|
* network stack should use this class for their output stream.
|
||||||
*/
|
*/
|
||||||
class OutputStream {
|
class OutputStream final {
|
||||||
public:
|
public:
|
||||||
OutputStream(io::network::Socket &socket) : socket_(socket) {}
|
OutputStream(
|
||||||
|
std::function<bool(const uint8_t *, size_t, bool)> write_function)
|
||||||
|
: write_function_(write_function) {}
|
||||||
|
|
||||||
|
OutputStream(const OutputStream &) = delete;
|
||||||
|
OutputStream(OutputStream &&) = delete;
|
||||||
|
OutputStream &operator=(const OutputStream &) = delete;
|
||||||
|
OutputStream &operator=(OutputStream &&) = delete;
|
||||||
|
|
||||||
bool Write(const uint8_t *data, size_t len, bool have_more = false) {
|
bool Write(const uint8_t *data, size_t len, bool have_more = false) {
|
||||||
return socket_.Write(data, len, have_more);
|
return write_function_(data, len, have_more);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Write(const std::string &str, bool have_more = false) {
|
bool Write(const std::string &str, bool have_more = false) {
|
||||||
@ -49,7 +62,7 @@ class OutputStream {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
io::network::Socket &socket_;
|
std::function<bool(const uint8_t *, size_t, bool)> write_function_;
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -58,20 +71,73 @@ class OutputStream {
|
|||||||
* wrapping.
|
* wrapping.
|
||||||
*/
|
*/
|
||||||
template <class TSession, class TSessionData>
|
template <class TSession, class TSessionData>
|
||||||
class Session {
|
class Session final {
|
||||||
public:
|
public:
|
||||||
Session(io::network::Socket &&socket, TSessionData &data,
|
Session(io::network::Socket &&socket, TSessionData &data,
|
||||||
int inactivity_timeout_sec)
|
ServerContext *context, int inactivity_timeout_sec)
|
||||||
: socket_(std::move(socket)),
|
: socket_(std::move(socket)),
|
||||||
output_stream_(socket_),
|
output_stream_([this](const uint8_t *data, size_t len, bool have_more) {
|
||||||
|
return Write(data, len, have_more);
|
||||||
|
}),
|
||||||
session_(data, input_buffer_.read_end(), output_stream_),
|
session_(data, input_buffer_.read_end(), output_stream_),
|
||||||
inactivity_timeout_sec_(inactivity_timeout_sec) {}
|
inactivity_timeout_sec_(inactivity_timeout_sec) {
|
||||||
|
// Set socket options.
|
||||||
|
// The socket is set to be a non-blocking socket. We use the socket in a
|
||||||
|
// non-blocking fashion for reads and manually simulate a blocking socket
|
||||||
|
// type for writes. This manual handling of writes is necessary because
|
||||||
|
// OpenSSL doesn't provide a way to add `recv` parameters to the `SSL_read`
|
||||||
|
// call so we can't have a blocking socket and use it in a non-blocking way
|
||||||
|
// only for reads.
|
||||||
|
// Keep alive is enabled so that the Kernel's TCP stack notifies us if a
|
||||||
|
// connection is broken and shouldn't be used anymore.
|
||||||
|
// Because we manually always set the `have_more` flag to the socket
|
||||||
|
// `Write` call we can disable the Nagle algorithm because we know that we
|
||||||
|
// are always sending optimal packets. Even if we don't send optimal
|
||||||
|
// packets, there will be no delay between packets and throughput won't
|
||||||
|
// suffer.
|
||||||
|
socket_.SetNonBlocking();
|
||||||
|
socket_.SetKeepAlive();
|
||||||
|
socket_.SetNoDelay();
|
||||||
|
|
||||||
|
// Prepare SSL if we should be using it.
|
||||||
|
if (context->use_ssl()) {
|
||||||
|
// Create a new SSL object that will be used for SSL communication.
|
||||||
|
ssl_ = SSL_new(context->context());
|
||||||
|
CHECK(ssl_ != nullptr) << "Couldn't create server SSL object!";
|
||||||
|
|
||||||
|
// Create a new BIO (block I/O) SSL object so that OpenSSL can communicate
|
||||||
|
// using our socket. We specify `BIO_NOCLOSE` to indicate to OpenSSL that
|
||||||
|
// it doesn't need to close the socket when destructing all objects (we
|
||||||
|
// handle that in our socket destructor).
|
||||||
|
bio_ = BIO_new_socket(socket_.fd(), BIO_NOCLOSE);
|
||||||
|
CHECK(bio_ != nullptr) << "Couldn't create server BIO object!";
|
||||||
|
|
||||||
|
// Connect the BIO object to the SSL object so that OpenSSL knows which
|
||||||
|
// stream it should use for communication. We use the same object for both
|
||||||
|
// the read and write end. This function cannot fail.
|
||||||
|
SSL_set_bio(ssl_, bio_, bio_);
|
||||||
|
|
||||||
|
// Indicate to OpenSSL that this connection is a server. The TLS handshake
|
||||||
|
// will be performed in the first `SSL_read` or `SSL_write` call. This
|
||||||
|
// function cannot fail.
|
||||||
|
SSL_set_accept_state(ssl_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Session(const Session &) = delete;
|
Session(const Session &) = delete;
|
||||||
Session(Session &&) = delete;
|
Session(Session &&) = delete;
|
||||||
Session &operator=(const Session &) = delete;
|
Session &operator=(const Session &) = delete;
|
||||||
Session &operator=(Session &&) = delete;
|
Session &operator=(Session &&) = delete;
|
||||||
|
|
||||||
|
~Session() {
|
||||||
|
// If we are using SSL we need to free the allocated objects. Here we only
|
||||||
|
// free the SSL object because the `SSL_free` function also automatically
|
||||||
|
// frees the BIO object.
|
||||||
|
if (ssl_) {
|
||||||
|
SSL_free(ssl_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This function is called from the communication stack when an event occurs
|
* This function is called from the communication stack when an event occurs
|
||||||
* indicating that there is data waiting to be read. This function calls the
|
* indicating that there is data waiting to be read. This function calls the
|
||||||
@ -87,29 +153,69 @@ class Session {
|
|||||||
|
|
||||||
// Allocate the buffer to fill the data.
|
// Allocate the buffer to fill the data.
|
||||||
auto buf = input_buffer_.write_end().Allocate();
|
auto buf = input_buffer_.write_end().Allocate();
|
||||||
// Read from the buffer at most buf.len bytes in a non-blocking fashion.
|
|
||||||
int len = socket_.Read(buf.data, buf.len, true);
|
|
||||||
|
|
||||||
// Check for read errors.
|
if (ssl_) {
|
||||||
if (len == -1) {
|
// We clear errors here to prevent errors piling up in the internal
|
||||||
// This means read would block or read was interrupted by signal, we
|
// OpenSSL error queue. To see when could that be an issue read this:
|
||||||
// return `true` to indicate that all data is processad and to stop
|
// https://www.arangodb.com/2014/07/started-hate-openssl/
|
||||||
// reading of data.
|
ERR_clear_error();
|
||||||
if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) {
|
|
||||||
return true;
|
// Read data from the socket using the OpenSSL API.
|
||||||
|
auto len = SSL_read(ssl_, buf.data, buf.len);
|
||||||
|
|
||||||
|
// Check for read errors.
|
||||||
|
if (len < 0) {
|
||||||
|
auto err = SSL_get_error(ssl_, len);
|
||||||
|
if (err == SSL_ERROR_WANT_READ) {
|
||||||
|
// OpenSSL want's to read more data from the socket. We return `true`
|
||||||
|
// to stop execution of the session to wait for more data to be
|
||||||
|
// received.
|
||||||
|
return true;
|
||||||
|
} else if (err == SSL_ERROR_WANT_WRITE) {
|
||||||
|
// The OpenSSL library wants to perfrom some kind of handshake so we
|
||||||
|
// wait for the socket to become ready for a write and call the read
|
||||||
|
// again. We return `false` so that the listener calls this function
|
||||||
|
// again.
|
||||||
|
socket_.WaitForReadyWrite();
|
||||||
|
return false;
|
||||||
|
} else {
|
||||||
|
// This is a fatal error.
|
||||||
|
throw utils::BasicException(SslGetLastError());
|
||||||
|
}
|
||||||
|
} else if (len == 0) {
|
||||||
|
// The client closed the connection.
|
||||||
|
throw SessionClosedException("Session was closed by the client.");
|
||||||
|
return false;
|
||||||
|
} else {
|
||||||
|
// Notify the input buffer that it has new data.
|
||||||
|
input_buffer_.write_end().Written(len);
|
||||||
}
|
}
|
||||||
// Some other error occurred, throw an exception to start session cleanup.
|
} else {
|
||||||
throw utils::BasicException("Couldn't read data from socket!");
|
// Read from the buffer at most buf.len bytes in a non-blocking fashion.
|
||||||
}
|
// Note, the `true` parameter for non-blocking here is redundant because
|
||||||
|
// the socket already is non-blocking.
|
||||||
|
auto len = socket_.Read(buf.data, buf.len, true);
|
||||||
|
|
||||||
// The client has closed the connection.
|
// Check for read errors.
|
||||||
if (len == 0) {
|
if (len == -1) {
|
||||||
throw SessionClosedException("Session was closed by client.");
|
// This means read would block or read was interrupted by signal, we
|
||||||
|
// return `true` to indicate that all data is processad and to stop
|
||||||
|
// reading of data.
|
||||||
|
if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
// Some other error occurred, throw an exception to start session
|
||||||
|
// cleanup.
|
||||||
|
throw utils::BasicException("Couldn't read data from the socket!");
|
||||||
|
} else if (len == 0) {
|
||||||
|
// The client has closed the connection.
|
||||||
|
throw SessionClosedException("Session was closed by client.");
|
||||||
|
} else {
|
||||||
|
// Notify the input buffer that it has new data.
|
||||||
|
input_buffer_.write_end().Written(len);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Notify the input buffer that it has new data.
|
|
||||||
input_buffer_.write_end().Written(len);
|
|
||||||
|
|
||||||
// Execute the session.
|
// Execute the session.
|
||||||
session_.Execute();
|
session_.Execute();
|
||||||
|
|
||||||
@ -142,6 +248,52 @@ class Session {
|
|||||||
last_event_time_ = std::chrono::steady_clock::now();
|
last_event_time_ = std::chrono::steady_clock::now();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO (mferencevic): the `have_more` flag currently isn't supported
|
||||||
|
// when using OpenSSL
|
||||||
|
bool Write(const uint8_t *data, size_t len, bool have_more = false) {
|
||||||
|
if (ssl_) {
|
||||||
|
// `SSL_write` has the interface of a normal `write` call. Because of that
|
||||||
|
// we need to ensure that all data is written to the socket manually.
|
||||||
|
while (len > 0) {
|
||||||
|
// We clear errors here to prevent errors piling up in the internal
|
||||||
|
// OpenSSL error queue. To see when could that be an issue read this:
|
||||||
|
// https://www.arangodb.com/2014/07/started-hate-openssl/
|
||||||
|
ERR_clear_error();
|
||||||
|
|
||||||
|
// Write data to the socket using OpenSSL.
|
||||||
|
auto written = SSL_write(ssl_, data, len);
|
||||||
|
if (written < 0) {
|
||||||
|
auto err = SSL_get_error(ssl_, written);
|
||||||
|
if (err == SSL_ERROR_WANT_READ) {
|
||||||
|
// OpenSSL wants to perform some kind of handshake, we need to
|
||||||
|
// ensure that there is data available for the next call to
|
||||||
|
// `SSL_write`.
|
||||||
|
socket_.WaitForReadyRead();
|
||||||
|
} else if (err == SSL_ERROR_WANT_WRITE) {
|
||||||
|
// The socket probably returned WOULDBLOCK and we need to wait for
|
||||||
|
// the output buffers to clear and reattempt the send.
|
||||||
|
socket_.WaitForReadyWrite();
|
||||||
|
} else {
|
||||||
|
// This is a fatal error.
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} else if (written == 0) {
|
||||||
|
// The client closed the connection.
|
||||||
|
return false;
|
||||||
|
} else {
|
||||||
|
len -= written;
|
||||||
|
data += written;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
// This function guarantees that all data will be written to the socket
|
||||||
|
// even if the socket is non-blocking. It will use a non-busy wait to send
|
||||||
|
// all data.
|
||||||
|
return socket_.Write(data, len, have_more);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// We own the socket.
|
// We own the socket.
|
||||||
io::network::Socket socket_;
|
io::network::Socket socket_;
|
||||||
|
|
||||||
@ -157,5 +309,9 @@ class Session {
|
|||||||
std::chrono::steady_clock::now()};
|
std::chrono::steady_clock::now()};
|
||||||
utils::SpinLock lock_;
|
utils::SpinLock lock_;
|
||||||
const int inactivity_timeout_sec_;
|
const int inactivity_timeout_sec_;
|
||||||
};
|
|
||||||
|
// SSL objects.
|
||||||
|
SSL *ssl_{nullptr};
|
||||||
|
BIO *bio_{nullptr};
|
||||||
|
}; // namespace communication
|
||||||
} // namespace communication
|
} // namespace communication
|
||||||
|
@ -11,6 +11,7 @@
|
|||||||
#include <netdb.h>
|
#include <netdb.h>
|
||||||
#include <netinet/in.h>
|
#include <netinet/in.h>
|
||||||
#include <netinet/tcp.h>
|
#include <netinet/tcp.h>
|
||||||
|
#include <poll.h>
|
||||||
#include <sys/epoll.h>
|
#include <sys/epoll.h>
|
||||||
#include <sys/socket.h>
|
#include <sys/socket.h>
|
||||||
#include <sys/types.h>
|
#include <sys/types.h>
|
||||||
@ -219,8 +220,13 @@ bool Socket::Write(const uint8_t *data, size_t len, bool have_more) {
|
|||||||
// Terminal error, return failure.
|
// Terminal error, return failure.
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
// Non-fatal error, retry.
|
// Non-fatal error, retry after the socket is ready. This is here to
|
||||||
continue;
|
// implement a non-busy wait. If we just continue with the loop we have a
|
||||||
|
// busy wait.
|
||||||
|
if (!WaitForReadyWrite()) return false;
|
||||||
|
} else if (written == 0) {
|
||||||
|
// The client closed the connection.
|
||||||
|
return false;
|
||||||
} else {
|
} else {
|
||||||
len -= written;
|
len -= written;
|
||||||
data += written;
|
data += written;
|
||||||
@ -234,7 +240,32 @@ bool Socket::Write(const std::string &s, bool have_more) {
|
|||||||
have_more);
|
have_more);
|
||||||
}
|
}
|
||||||
|
|
||||||
int Socket::Read(void *buffer, size_t len, bool nonblock) {
|
ssize_t Socket::Read(void *buffer, size_t len, bool nonblock) {
|
||||||
return recv(socket_, buffer, len, nonblock ? MSG_DONTWAIT : 0);
|
return recv(socket_, buffer, len, nonblock ? MSG_DONTWAIT : 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool Socket::WaitForReadyRead() {
|
||||||
|
struct pollfd p;
|
||||||
|
p.fd = socket_;
|
||||||
|
p.events = POLLIN;
|
||||||
|
// We call poll with one element in the poll fds array (first and second
|
||||||
|
// arguments), also we set the timeout to -1 to block indefinitely until an
|
||||||
|
// event occurs.
|
||||||
|
int ret = poll(&p, 1, -1);
|
||||||
|
if (ret < 1) return false;
|
||||||
|
return p.revents & POLLIN;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Socket::WaitForReadyWrite() {
|
||||||
|
struct pollfd p;
|
||||||
|
p.fd = socket_;
|
||||||
|
p.events = POLLOUT;
|
||||||
|
// We call poll with one element in the poll fds array (first and second
|
||||||
|
// arguments), also we set the timeout to -1 to block indefinitely until an
|
||||||
|
// event occurs.
|
||||||
|
int ret = poll(&p, 1, -1);
|
||||||
|
if (ret < 1) return false;
|
||||||
|
return p.revents & POLLOUT;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace io::network
|
} // namespace io::network
|
||||||
|
@ -153,7 +153,41 @@ class Socket {
|
|||||||
* == 0 if the client closed the connection
|
* == 0 if the client closed the connection
|
||||||
* < 0 if an error has occurred
|
* < 0 if an error has occurred
|
||||||
*/
|
*/
|
||||||
int Read(void *buffer, size_t len, bool nonblock = false);
|
ssize_t Read(void *buffer, size_t len, bool nonblock = false);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Wait until the socket becomes ready for a `Read` operation.
|
||||||
|
* This function blocks indefinitely waiting for the socket to change its
|
||||||
|
* state. This function is useful when you need a blocking operation on a
|
||||||
|
* non-blocking socket, you can call this function to ensure that your next
|
||||||
|
* `Read` operation will succeed.
|
||||||
|
*
|
||||||
|
* The function returns `true` if the wait succeded (there is data waiting to
|
||||||
|
* be read from the socket) and returns `false` if the wait failed (the socket
|
||||||
|
* was closed or something else bad happened).
|
||||||
|
*
|
||||||
|
* @return wait success status:
|
||||||
|
* true if the wait succeeded
|
||||||
|
* false if the wait failed
|
||||||
|
*/
|
||||||
|
bool WaitForReadyRead();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Wait until the socket becomes ready for a `Write` operation.
|
||||||
|
* This function blocks indefinitely waiting for the socket to change its
|
||||||
|
* state. This function is useful when you need a blocking operation on a
|
||||||
|
* non-blocking socket, you can call this function to ensure that your next
|
||||||
|
* `Write` operation will succeed.
|
||||||
|
*
|
||||||
|
* The function returns `true` if the wait succeded (the socket can be written
|
||||||
|
* to) and returns `false` if the wait failed (the socket was closed or
|
||||||
|
* something else bad happened).
|
||||||
|
*
|
||||||
|
* @return wait success status:
|
||||||
|
* true if the wait succeeded
|
||||||
|
* false if the wait failed
|
||||||
|
*/
|
||||||
|
bool WaitForReadyWrite();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Socket(int fd, const Endpoint &endpoint) : socket_(fd), endpoint_(endpoint) {}
|
Socket(int fd, const Endpoint &endpoint) : socket_(fd), endpoint_(endpoint) {}
|
||||||
|
@ -28,6 +28,7 @@ using communication::bolt::SessionData;
|
|||||||
using SessionT = communication::bolt::Session<communication::InputStream,
|
using SessionT = communication::bolt::Session<communication::InputStream,
|
||||||
communication::OutputStream>;
|
communication::OutputStream>;
|
||||||
using ServerT = communication::Server<SessionT, SessionData>;
|
using ServerT = communication::Server<SessionT, SessionData>;
|
||||||
|
using communication::ServerContext;
|
||||||
|
|
||||||
// General purpose flags.
|
// General purpose flags.
|
||||||
DEFINE_string(interface, "0.0.0.0",
|
DEFINE_string(interface, "0.0.0.0",
|
||||||
@ -41,6 +42,8 @@ DEFINE_VALIDATED_int32(session_inactivity_timeout, 1800,
|
|||||||
"Time in seconds after which inactive sessions will be "
|
"Time in seconds after which inactive sessions will be "
|
||||||
"closed.",
|
"closed.",
|
||||||
FLAG_IN_RANGE(1, INT32_MAX));
|
FLAG_IN_RANGE(1, INT32_MAX));
|
||||||
|
DEFINE_string(cert_file, "", "Certificate file to use.");
|
||||||
|
DEFINE_string(key_file, "", "Key file to use.");
|
||||||
DEFINE_string(log_file, "", "Path to where the log should be stored.");
|
DEFINE_string(log_file, "", "Path to where the log should be stored.");
|
||||||
DEFINE_HIDDEN_string(
|
DEFINE_HIDDEN_string(
|
||||||
log_link_basename, "",
|
log_link_basename, "",
|
||||||
@ -142,6 +145,9 @@ int WithInit(int argc, char **argv,
|
|||||||
stats::InitStatsLogging(get_stats_prefix());
|
stats::InitStatsLogging(get_stats_prefix());
|
||||||
utils::OnScopeExit stop_stats([] { stats::StopStatsLogging(); });
|
utils::OnScopeExit stop_stats([] { stats::StopStatsLogging(); });
|
||||||
|
|
||||||
|
// Initialize the communication library.
|
||||||
|
communication::Init();
|
||||||
|
|
||||||
// Start memory warning logger.
|
// Start memory warning logger.
|
||||||
utils::Scheduler mem_log_scheduler;
|
utils::Scheduler mem_log_scheduler;
|
||||||
if (FLAGS_memory_warning_threshold > 0) {
|
if (FLAGS_memory_warning_threshold > 0) {
|
||||||
@ -160,9 +166,17 @@ void SingleNodeMain() {
|
|||||||
google::SetUsageMessage("Memgraph single-node database server");
|
google::SetUsageMessage("Memgraph single-node database server");
|
||||||
database::SingleNode db;
|
database::SingleNode db;
|
||||||
SessionData session_data{db};
|
SessionData session_data{db};
|
||||||
|
|
||||||
|
ServerContext context;
|
||||||
|
std::string service_name = "Bolt";
|
||||||
|
if (FLAGS_key_file != "" && FLAGS_cert_file != "") {
|
||||||
|
context = ServerContext(FLAGS_key_file, FLAGS_cert_file);
|
||||||
|
service_name = "BoltS";
|
||||||
|
}
|
||||||
|
|
||||||
ServerT server({FLAGS_interface, static_cast<uint16_t>(FLAGS_port)},
|
ServerT server({FLAGS_interface, static_cast<uint16_t>(FLAGS_port)},
|
||||||
session_data, FLAGS_session_inactivity_timeout, "Bolt",
|
session_data, &context, FLAGS_session_inactivity_timeout,
|
||||||
FLAGS_num_workers);
|
service_name, FLAGS_num_workers);
|
||||||
|
|
||||||
// Setup telemetry
|
// Setup telemetry
|
||||||
std::experimental::optional<telemetry::Telemetry> telemetry;
|
std::experimental::optional<telemetry::Telemetry> telemetry;
|
||||||
@ -214,9 +228,17 @@ void MasterMain() {
|
|||||||
|
|
||||||
database::Master db;
|
database::Master db;
|
||||||
SessionData session_data{db};
|
SessionData session_data{db};
|
||||||
|
|
||||||
|
ServerContext context;
|
||||||
|
std::string service_name = "Bolt";
|
||||||
|
if (FLAGS_key_file != "" && FLAGS_cert_file != "") {
|
||||||
|
context = ServerContext(FLAGS_key_file, FLAGS_cert_file);
|
||||||
|
service_name = "BoltS";
|
||||||
|
}
|
||||||
|
|
||||||
ServerT server({FLAGS_interface, static_cast<uint16_t>(FLAGS_port)},
|
ServerT server({FLAGS_interface, static_cast<uint16_t>(FLAGS_port)},
|
||||||
session_data, FLAGS_session_inactivity_timeout, "Bolt",
|
session_data, &context, FLAGS_session_inactivity_timeout,
|
||||||
FLAGS_num_workers);
|
service_name, FLAGS_num_workers);
|
||||||
|
|
||||||
// Handler for regular termination signals
|
// Handler for regular termination signals
|
||||||
auto shutdown = [&server] {
|
auto shutdown = [&server] {
|
||||||
|
@ -42,10 +42,11 @@ class TestSession {
|
|||||||
input_stream_.Shift(size + 2);
|
input_stream_.Shift(size + 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
communication::InputStream input_stream_;
|
communication::InputStream &input_stream_;
|
||||||
communication::OutputStream output_stream_;
|
communication::OutputStream &output_stream_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
using ContextT = communication::ServerContext;
|
||||||
using ServerT = communication::Server<TestSession, TestData>;
|
using ServerT = communication::Server<TestSession, TestData>;
|
||||||
|
|
||||||
void client_run(int num, const char *interface, uint16_t port,
|
void client_run(int num, const char *interface, uint16_t port,
|
||||||
|
@ -32,8 +32,8 @@ class TestSession {
|
|||||||
output_stream_.Write(input_stream_.data(), input_stream_.size());
|
output_stream_.Write(input_stream_.data(), input_stream_.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
communication::InputStream input_stream_;
|
communication::InputStream &input_stream_;
|
||||||
communication::OutputStream output_stream_;
|
communication::OutputStream &output_stream_;
|
||||||
};
|
};
|
||||||
|
|
||||||
std::atomic<bool> run{true};
|
std::atomic<bool> run{true};
|
||||||
@ -63,8 +63,9 @@ TEST(Network, SocketReadHangOnConcurrentConnections) {
|
|||||||
TestData data;
|
TestData data;
|
||||||
int N = (std::thread::hardware_concurrency() + 1) / 2;
|
int N = (std::thread::hardware_concurrency() + 1) / 2;
|
||||||
int Nc = N * 3;
|
int Nc = N * 3;
|
||||||
communication::Server<TestSession, TestData> server(endpoint, data, -1,
|
communication::ServerContext context;
|
||||||
"Test", N);
|
communication::Server<TestSession, TestData> server(endpoint, data, &context,
|
||||||
|
-1, "Test", N);
|
||||||
|
|
||||||
const auto &ep = server.endpoint();
|
const auto &ep = server.endpoint();
|
||||||
// start clients
|
// start clients
|
||||||
|
@ -21,7 +21,8 @@ TEST(Network, Server) {
|
|||||||
// initialize server
|
// initialize server
|
||||||
TestData session_data;
|
TestData session_data;
|
||||||
int N = (std::thread::hardware_concurrency() + 1) / 2;
|
int N = (std::thread::hardware_concurrency() + 1) / 2;
|
||||||
ServerT server(endpoint, session_data, -1, "Test", N);
|
ContextT context;
|
||||||
|
ServerT server(endpoint, session_data, &context, -1, "Test", N);
|
||||||
|
|
||||||
const auto &ep = server.endpoint();
|
const auto &ep = server.endpoint();
|
||||||
// start clients
|
// start clients
|
||||||
|
@ -22,7 +22,8 @@ TEST(Network, SessionLeak) {
|
|||||||
|
|
||||||
// initialize server
|
// initialize server
|
||||||
TestData session_data;
|
TestData session_data;
|
||||||
ServerT server(endpoint, session_data, -1, "Test", 2);
|
ContextT context;
|
||||||
|
ServerT server(endpoint, session_data, &context, -1, "Test", 2);
|
||||||
|
|
||||||
// start clients
|
// start clients
|
||||||
int N = 50;
|
int N = 50;
|
||||||
|
@ -1,2 +1,5 @@
|
|||||||
# telemetry test binaries
|
# telemetry test binaries
|
||||||
add_subdirectory(telemetry)
|
add_subdirectory(telemetry)
|
||||||
|
|
||||||
|
# ssl test binaries
|
||||||
|
add_subdirectory(ssl)
|
||||||
|
@ -6,3 +6,11 @@
|
|||||||
- server.py # server script
|
- server.py # server script
|
||||||
- ../../../build_debug/tests/integration/telemetry/client # client binary
|
- ../../../build_debug/tests/integration/telemetry/client # client binary
|
||||||
- ../../../build_debug/tests/manual/kvstore_console # kvstore console binary
|
- ../../../build_debug/tests/manual/kvstore_console # kvstore console binary
|
||||||
|
|
||||||
|
- name: integration__ssl
|
||||||
|
cd: ssl
|
||||||
|
commands: ./runner.sh
|
||||||
|
infiles:
|
||||||
|
- runner.sh # runner script
|
||||||
|
- ../../../build_debug/tests/integration/ssl/tester # tester binary
|
||||||
|
enable_network: true
|
||||||
|
6
tests/integration/ssl/CMakeLists.txt
Normal file
6
tests/integration/ssl/CMakeLists.txt
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
set(target_name memgraph__integration__ssl)
|
||||||
|
set(tester_target_name ${target_name}__tester)
|
||||||
|
|
||||||
|
add_executable(${tester_target_name} tester.cpp)
|
||||||
|
set_target_properties(${tester_target_name} PROPERTIES OUTPUT_NAME tester)
|
||||||
|
target_link_libraries(${tester_target_name} memgraph_lib kvstore_dummy_lib)
|
78
tests/integration/ssl/runner.sh
Executable file
78
tests/integration/ssl/runner.sh
Executable file
@ -0,0 +1,78 @@
|
|||||||
|
#!/bin/bash -e
|
||||||
|
|
||||||
|
pushd () { command pushd "$@" > /dev/null; }
|
||||||
|
popd () { command popd "$@" > /dev/null; }
|
||||||
|
die () { printf "\033[1;31m~~ Test failed! ~~\033[0m\n\n"; exit 1; }
|
||||||
|
|
||||||
|
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
|
||||||
|
cd "$DIR"
|
||||||
|
|
||||||
|
tmpdir=/tmp/memgraph_integration_ssl
|
||||||
|
|
||||||
|
if [ -d $tmpdir ]; then
|
||||||
|
rm -rf $tmpdir
|
||||||
|
fi
|
||||||
|
mkdir -p $tmpdir
|
||||||
|
cd $tmpdir
|
||||||
|
|
||||||
|
easyrsa="EasyRSA-3.0.4"
|
||||||
|
wget -nv http://deps.memgraph.io/$easyrsa.tgz
|
||||||
|
|
||||||
|
tar -xf $easyrsa.tgz
|
||||||
|
mv $easyrsa ca1
|
||||||
|
cp -r ca1 ca2
|
||||||
|
|
||||||
|
for i in ca1 ca2; do
|
||||||
|
pushd $i
|
||||||
|
./easyrsa --batch init-pki
|
||||||
|
./easyrsa --batch build-ca nopass
|
||||||
|
./easyrsa --batch build-server-full server nopass
|
||||||
|
./easyrsa --batch build-client-full client nopass
|
||||||
|
popd
|
||||||
|
done
|
||||||
|
|
||||||
|
binary_dir="$DIR/../../../build"
|
||||||
|
if [ ! -d $binary_dir ]; then
|
||||||
|
binary_dir="$DIR/../../../build_debug"
|
||||||
|
fi
|
||||||
|
|
||||||
|
set +e
|
||||||
|
|
||||||
|
echo
|
||||||
|
for i in ca1 ca2; do
|
||||||
|
for j in none ca1 ca2; do
|
||||||
|
for k in false true; do
|
||||||
|
printf "\033[1;36m~~ Server CA: $i; Client CA: $j; Verify peer: $k ~~\033[0m\n"
|
||||||
|
|
||||||
|
if [ "$j" == "none" ]; then
|
||||||
|
client_key=""
|
||||||
|
client_cert=""
|
||||||
|
else
|
||||||
|
client_key=$j/pki/private/client.key
|
||||||
|
client_cert=$j/pki/issued/client.crt
|
||||||
|
fi
|
||||||
|
|
||||||
|
$binary_dir/tests/integration/ssl/tester \
|
||||||
|
--server-key-file=$i/pki/private/server.key \
|
||||||
|
--server-cert-file=$i/pki/issued/server.crt \
|
||||||
|
--server-ca-file=$i/pki/ca.crt \
|
||||||
|
--server-verify-peer=$k \
|
||||||
|
--client-key-file=$client_key \
|
||||||
|
--client-cert-file=$client_cert
|
||||||
|
|
||||||
|
exitcode=$?
|
||||||
|
|
||||||
|
if [ "$i" == "$j" ]; then
|
||||||
|
[ $exitcode -eq 0 ] || die
|
||||||
|
else
|
||||||
|
if $k; then
|
||||||
|
[ $exitcode -ne 0 ] || die
|
||||||
|
else
|
||||||
|
[ $exitcode -eq 0 ] || die
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
printf "\033[1;32m~~ Test ok! ~~\033[0m\n\n"
|
||||||
|
done
|
||||||
|
done
|
||||||
|
done
|
76
tests/integration/ssl/tester.cpp
Normal file
76
tests/integration/ssl/tester.cpp
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
#include <atomic>
|
||||||
|
|
||||||
|
#include <gflags/gflags.h>
|
||||||
|
#include <glog/logging.h>
|
||||||
|
|
||||||
|
#include "communication/client.hpp"
|
||||||
|
#include "communication/server.hpp"
|
||||||
|
#include "utils/exceptions.hpp"
|
||||||
|
|
||||||
|
DEFINE_string(server_cert_file, "", "Server certificate file to use.");
|
||||||
|
DEFINE_string(server_key_file, "", "Server key file to use.");
|
||||||
|
DEFINE_string(server_ca_file, "", "Server CA file to use.");
|
||||||
|
DEFINE_bool(server_verify_peer, false, "Set to true to verify the peer.");
|
||||||
|
|
||||||
|
DEFINE_string(client_cert_file, "", "Client certificate file to use.");
|
||||||
|
DEFINE_string(client_key_file, "", "Client key file to use.");
|
||||||
|
|
||||||
|
const std::string message = "ssl echo test";
|
||||||
|
|
||||||
|
struct EchoData {};
|
||||||
|
|
||||||
|
class EchoSession {
|
||||||
|
public:
|
||||||
|
EchoSession(EchoData &, communication::InputStream &input_stream,
|
||||||
|
communication::OutputStream &output_stream)
|
||||||
|
: input_stream_(input_stream), output_stream_(output_stream) {}
|
||||||
|
|
||||||
|
void Execute() {
|
||||||
|
if (input_stream_.size() < message.size()) return;
|
||||||
|
LOG(INFO) << "Server received message.";
|
||||||
|
if (!output_stream_.Write(input_stream_.data(), message.size())) {
|
||||||
|
throw utils::BasicException("Output stream write failed!");
|
||||||
|
}
|
||||||
|
LOG(INFO) << "Server sent message.";
|
||||||
|
input_stream_.Shift(message.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
communication::InputStream &input_stream_;
|
||||||
|
communication::OutputStream &output_stream_;
|
||||||
|
};
|
||||||
|
|
||||||
|
int main(int argc, char **argv) {
|
||||||
|
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||||
|
google::InitGoogleLogging(argv[0]);
|
||||||
|
|
||||||
|
// Initialize the communication stack.
|
||||||
|
communication::Init();
|
||||||
|
|
||||||
|
// Initialize the server.
|
||||||
|
EchoData echo_data;
|
||||||
|
communication::ServerContext server_context(
|
||||||
|
FLAGS_server_key_file, FLAGS_server_cert_file, FLAGS_server_ca_file,
|
||||||
|
FLAGS_server_verify_peer);
|
||||||
|
communication::Server<EchoSession, EchoData> server(
|
||||||
|
{"127.0.0.1", 0}, echo_data, &server_context, -1, "SSL", 1);
|
||||||
|
|
||||||
|
// Initialize the client.
|
||||||
|
communication::ClientContext client_context(FLAGS_client_key_file,
|
||||||
|
FLAGS_client_cert_file);
|
||||||
|
communication::Client client(&client_context);
|
||||||
|
|
||||||
|
// Connect to the server.
|
||||||
|
CHECK(client.Connect(server.endpoint())) << "Couldn't connect to server!";
|
||||||
|
|
||||||
|
// Perform echo.
|
||||||
|
CHECK(client.Write(message)) << "Client couldn't send message!";
|
||||||
|
LOG(INFO) << "Client sent message.";
|
||||||
|
CHECK(client.Read(message.size())) << "Client couldn't receive message!";
|
||||||
|
LOG(INFO) << "Client received message.";
|
||||||
|
CHECK(std::string(reinterpret_cast<const char *>(client.GetData()),
|
||||||
|
message.size()) == message)
|
||||||
|
<< "Received message isn't equal to sent message!";
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
@ -10,6 +10,7 @@
|
|||||||
#include "io/network/endpoint.hpp"
|
#include "io/network/endpoint.hpp"
|
||||||
|
|
||||||
using EndpointT = io::network::Endpoint;
|
using EndpointT = io::network::Endpoint;
|
||||||
|
using ContextT = communication::ClientContext;
|
||||||
using ClientT = communication::bolt::Client;
|
using ClientT = communication::bolt::Client;
|
||||||
using QueryDataT = communication::bolt::QueryData;
|
using QueryDataT = communication::bolt::QueryData;
|
||||||
using communication::bolt::DecodedValue;
|
using communication::bolt::DecodedValue;
|
||||||
@ -18,23 +19,23 @@ class BoltClient {
|
|||||||
public:
|
public:
|
||||||
BoltClient(const std::string &address, uint16_t port,
|
BoltClient(const std::string &address, uint16_t port,
|
||||||
const std::string &username, const std::string &password,
|
const std::string &username, const std::string &password,
|
||||||
const std::string & = "") {
|
const std::string & = "", bool use_ssl = false)
|
||||||
|
: context_(use_ssl), client_(context_) {
|
||||||
EndpointT endpoint(address, port);
|
EndpointT endpoint(address, port);
|
||||||
client_ = std::make_unique<ClientT>();
|
|
||||||
|
|
||||||
if (!client_->Connect(endpoint, username, password)) {
|
if (!client_.Connect(endpoint, username, password)) {
|
||||||
LOG(FATAL) << "Could not connect to: " << endpoint;
|
LOG(FATAL) << "Could not connect to: " << endpoint;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
QueryDataT Execute(const std::string &query,
|
QueryDataT Execute(const std::string &query,
|
||||||
const std::map<std::string, DecodedValue> ¶meters) {
|
const std::map<std::string, DecodedValue> ¶meters) {
|
||||||
return client_->Execute(query, parameters);
|
return client_.Execute(query, parameters);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Close() { client_->Close(); }
|
void Close() { client_.Close(); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::unique_ptr<ClientT> client_;
|
ContextT context_;
|
||||||
|
ClientT client_;
|
||||||
};
|
};
|
||||||
|
@ -331,11 +331,14 @@ int main(int argc, char **argv) {
|
|||||||
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||||
google::InitGoogleLogging(argv[0]);
|
google::InitGoogleLogging(argv[0]);
|
||||||
|
|
||||||
|
communication::Init();
|
||||||
|
|
||||||
stats::InitStatsLogging(
|
stats::InitStatsLogging(
|
||||||
fmt::format("client.long_running.{}.{}", FLAGS_group, FLAGS_scenario));
|
fmt::format("client.long_running.{}.{}", FLAGS_group, FLAGS_scenario));
|
||||||
|
|
||||||
Endpoint endpoint(FLAGS_address, FLAGS_port);
|
Endpoint endpoint(FLAGS_address, FLAGS_port);
|
||||||
Client client;
|
ClientContext context(FLAGS_use_ssl);
|
||||||
|
Client client(&context);
|
||||||
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) {
|
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) {
|
||||||
LOG(FATAL) << "Couldn't connect to " << endpoint;
|
LOG(FATAL) << "Couldn't connect to " << endpoint;
|
||||||
}
|
}
|
||||||
|
@ -11,6 +11,7 @@
|
|||||||
#include "utils/exceptions.hpp"
|
#include "utils/exceptions.hpp"
|
||||||
#include "utils/timer.hpp"
|
#include "utils/timer.hpp"
|
||||||
|
|
||||||
|
using communication::ClientContext;
|
||||||
using communication::bolt::Client;
|
using communication::bolt::Client;
|
||||||
using communication::bolt::DecodedValue;
|
using communication::bolt::DecodedValue;
|
||||||
using io::network::Endpoint;
|
using io::network::Endpoint;
|
||||||
|
@ -16,6 +16,7 @@ DEFINE_int32(num_workers, 1, "Number of workers");
|
|||||||
DEFINE_string(output, "", "Output file");
|
DEFINE_string(output, "", "Output file");
|
||||||
DEFINE_string(username, "", "Username for the database");
|
DEFINE_string(username, "", "Username for the database");
|
||||||
DEFINE_string(password, "", "Password for the database");
|
DEFINE_string(password, "", "Password for the database");
|
||||||
|
DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server.");
|
||||||
DEFINE_int32(duration, 30, "Number of seconds to execute benchmark");
|
DEFINE_int32(duration, 30, "Number of seconds to execute benchmark");
|
||||||
|
|
||||||
DEFINE_string(group, "unknown", "Test group name");
|
DEFINE_string(group, "unknown", "Test group name");
|
||||||
@ -97,7 +98,8 @@ class TestClient {
|
|||||||
std::thread runner_thread_;
|
std::thread runner_thread_;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Client client_;
|
communication::ClientContext context_{FLAGS_use_ssl};
|
||||||
|
Client client_{&context_};
|
||||||
};
|
};
|
||||||
|
|
||||||
void RunMultithreadedTest(std::vector<std::unique_ptr<TestClient>> &clients) {
|
void RunMultithreadedTest(std::vector<std::unique_ptr<TestClient>> &clients) {
|
||||||
|
@ -271,12 +271,15 @@ int main(int argc, char **argv) {
|
|||||||
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||||
google::InitGoogleLogging(argv[0]);
|
google::InitGoogleLogging(argv[0]);
|
||||||
|
|
||||||
|
communication::Init();
|
||||||
|
|
||||||
nlohmann::json config;
|
nlohmann::json config;
|
||||||
std::cin >> config;
|
std::cin >> config;
|
||||||
|
|
||||||
auto independent_nodes_ids = [&] {
|
auto independent_nodes_ids = [&] {
|
||||||
Endpoint endpoint(io::network::ResolveHostname(FLAGS_address), FLAGS_port);
|
Endpoint endpoint(io::network::ResolveHostname(FLAGS_address), FLAGS_port);
|
||||||
Client client;
|
ClientContext context(FLAGS_use_ssl);
|
||||||
|
Client client(&context);
|
||||||
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) {
|
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) {
|
||||||
LOG(FATAL) << "Couldn't connect to " << endpoint;
|
LOG(FATAL) << "Couldn't connect to " << endpoint;
|
||||||
}
|
}
|
||||||
|
@ -19,6 +19,7 @@ DEFINE_string(address, "127.0.0.1", "Server address");
|
|||||||
DEFINE_int32(port, 7687, "Server port");
|
DEFINE_int32(port, 7687, "Server port");
|
||||||
DEFINE_string(username, "", "Username for the database");
|
DEFINE_string(username, "", "Username for the database");
|
||||||
DEFINE_string(password, "", "Password for the database");
|
DEFINE_string(password, "", "Password for the database");
|
||||||
|
DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server.");
|
||||||
|
|
||||||
using communication::bolt::DecodedValue;
|
using communication::bolt::DecodedValue;
|
||||||
|
|
||||||
@ -58,7 +59,8 @@ void ExecuteQueries(const std::vector<std::string> &queries,
|
|||||||
for (int i = 0; i < FLAGS_num_workers; ++i) {
|
for (int i = 0; i < FLAGS_num_workers; ++i) {
|
||||||
threads.push_back(std::thread([&]() {
|
threads.push_back(std::thread([&]() {
|
||||||
Endpoint endpoint(FLAGS_address, FLAGS_port);
|
Endpoint endpoint(FLAGS_address, FLAGS_port);
|
||||||
Client client;
|
ClientContext context(FLAGS_use_ssl);
|
||||||
|
Client client(&context);
|
||||||
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) {
|
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) {
|
||||||
LOG(FATAL) << "Couldn't connect to " << endpoint;
|
LOG(FATAL) << "Couldn't connect to " << endpoint;
|
||||||
}
|
}
|
||||||
@ -100,6 +102,8 @@ int main(int argc, char **argv) {
|
|||||||
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||||
google::InitGoogleLogging(argv[0]);
|
google::InitGoogleLogging(argv[0]);
|
||||||
|
|
||||||
|
communication::Init();
|
||||||
|
|
||||||
std::ifstream ifile;
|
std::ifstream ifile;
|
||||||
std::istream *istream{&std::cin};
|
std::istream *istream{&std::cin};
|
||||||
|
|
||||||
|
@ -71,6 +71,12 @@ target_link_libraries(${test_prefix}sl_position_and_count memgraph_lib kvstore_d
|
|||||||
add_manual_test(stripped_timing.cpp)
|
add_manual_test(stripped_timing.cpp)
|
||||||
target_link_libraries(${test_prefix}stripped_timing memgraph_lib kvstore_dummy_lib)
|
target_link_libraries(${test_prefix}stripped_timing memgraph_lib kvstore_dummy_lib)
|
||||||
|
|
||||||
|
add_manual_test(ssl_client.cpp)
|
||||||
|
target_link_libraries(${test_prefix}ssl_client memgraph_lib kvstore_dummy_lib)
|
||||||
|
|
||||||
|
add_manual_test(ssl_server.cpp)
|
||||||
|
target_link_libraries(${test_prefix}ssl_server memgraph_lib kvstore_dummy_lib)
|
||||||
|
|
||||||
add_manual_test(xorshift.cpp)
|
add_manual_test(xorshift.cpp)
|
||||||
target_link_libraries(${test_prefix}xorshift mg-utils)
|
target_link_libraries(${test_prefix}xorshift mg-utils)
|
||||||
|
|
||||||
|
@ -10,15 +10,20 @@ DEFINE_string(address, "127.0.0.1", "Server address");
|
|||||||
DEFINE_int32(port, 7687, "Server port");
|
DEFINE_int32(port, 7687, "Server port");
|
||||||
DEFINE_string(username, "", "Username for the database");
|
DEFINE_string(username, "", "Username for the database");
|
||||||
DEFINE_string(password, "", "Password for the database");
|
DEFINE_string(password, "", "Password for the database");
|
||||||
|
DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server.");
|
||||||
|
|
||||||
int main(int argc, char **argv) {
|
int main(int argc, char **argv) {
|
||||||
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||||
google::InitGoogleLogging(argv[0]);
|
google::InitGoogleLogging(argv[0]);
|
||||||
|
|
||||||
|
communication::Init();
|
||||||
|
|
||||||
// TODO: handle endpoint exception
|
// TODO: handle endpoint exception
|
||||||
io::network::Endpoint endpoint(io::network::ResolveHostname(FLAGS_address),
|
io::network::Endpoint endpoint(io::network::ResolveHostname(FLAGS_address),
|
||||||
FLAGS_port);
|
FLAGS_port);
|
||||||
communication::bolt::Client client;
|
|
||||||
|
communication::ClientContext context(FLAGS_use_ssl);
|
||||||
|
communication::bolt::Client client(&context);
|
||||||
|
|
||||||
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) return 1;
|
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) return 1;
|
||||||
|
|
||||||
|
65
tests/manual/ssl_client.cpp
Normal file
65
tests/manual/ssl_client.cpp
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
#include <gflags/gflags.h>
|
||||||
|
#include <glog/logging.h>
|
||||||
|
|
||||||
|
#include "communication/client.hpp"
|
||||||
|
#include "io/network/endpoint.hpp"
|
||||||
|
#include "utils/timer.hpp"
|
||||||
|
|
||||||
|
DEFINE_string(address, "127.0.0.1", "Server address");
|
||||||
|
DEFINE_int32(port, 54321, "Server port");
|
||||||
|
DEFINE_string(cert_file, "", "Certificate file to use.");
|
||||||
|
DEFINE_string(key_file, "", "Key file to use.");
|
||||||
|
|
||||||
|
bool EchoMessage(communication::Client &client, const std::string &data) {
|
||||||
|
uint16_t size = data.size();
|
||||||
|
if (!client.Write(reinterpret_cast<const uint8_t *>(&size), sizeof(size))) {
|
||||||
|
LOG(WARNING) << "Couldn't send data size!";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (!client.Write(data)) {
|
||||||
|
LOG(WARNING) << "Couldn't send data!";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
client.ClearData();
|
||||||
|
if (!client.Read(size)) {
|
||||||
|
LOG(WARNING) << "Couldn't receive data!";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (std::string(reinterpret_cast<const char *>(client.GetData()), size) !=
|
||||||
|
data) {
|
||||||
|
LOG(WARNING) << "Received data isn't equal to sent data!";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
int main(int argc, char **argv) {
|
||||||
|
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||||
|
google::InitGoogleLogging(argv[0]);
|
||||||
|
|
||||||
|
communication::Init();
|
||||||
|
|
||||||
|
io::network::Endpoint endpoint(FLAGS_address, FLAGS_port);
|
||||||
|
|
||||||
|
communication::ClientContext context(FLAGS_key_file, FLAGS_cert_file);
|
||||||
|
communication::Client client(&context);
|
||||||
|
|
||||||
|
if (!client.Connect(endpoint)) return 1;
|
||||||
|
|
||||||
|
bool success = true;
|
||||||
|
while (true) {
|
||||||
|
std::string s;
|
||||||
|
std::getline(std::cin, s);
|
||||||
|
if (s == "") break;
|
||||||
|
if (!EchoMessage(client, s)) {
|
||||||
|
success = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send server shutdown signal. The call will fail, we don't care.
|
||||||
|
EchoMessage(client, "");
|
||||||
|
|
||||||
|
return success ? 0 : 1;
|
||||||
|
}
|
73
tests/manual/ssl_server.cpp
Normal file
73
tests/manual/ssl_server.cpp
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
#include <atomic>
|
||||||
|
|
||||||
|
#include <gflags/gflags.h>
|
||||||
|
#include <glog/logging.h>
|
||||||
|
|
||||||
|
#include "communication/server.hpp"
|
||||||
|
#include "utils/exceptions.hpp"
|
||||||
|
|
||||||
|
DEFINE_string(address, "127.0.0.1", "Server address");
|
||||||
|
DEFINE_int32(port, 54321, "Server port");
|
||||||
|
DEFINE_string(cert_file, "", "Certificate file to use.");
|
||||||
|
DEFINE_string(key_file, "", "Key file to use.");
|
||||||
|
DEFINE_string(ca_file, "", "CA file to use.");
|
||||||
|
DEFINE_bool(verify_peer, false, "Set to true to verify the peer.");
|
||||||
|
|
||||||
|
struct EchoData {
|
||||||
|
std::atomic<bool> alive{true};
|
||||||
|
};
|
||||||
|
|
||||||
|
class EchoSession {
|
||||||
|
public:
|
||||||
|
EchoSession(EchoData &data, communication::InputStream &input_stream,
|
||||||
|
communication::OutputStream &output_stream)
|
||||||
|
: data_(data),
|
||||||
|
input_stream_(input_stream),
|
||||||
|
output_stream_(output_stream) {}
|
||||||
|
|
||||||
|
void Execute() {
|
||||||
|
if (input_stream_.size() < 2) return;
|
||||||
|
const uint8_t *data = input_stream_.data();
|
||||||
|
uint16_t size = *reinterpret_cast<const uint16_t *>(input_stream_.data());
|
||||||
|
input_stream_.Resize(size + 2);
|
||||||
|
if (input_stream_.size() < size + 2) return;
|
||||||
|
if (size == 0) {
|
||||||
|
LOG(INFO) << "Server received EOF message";
|
||||||
|
data_.alive.store(false);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
LOG(INFO) << "Server received '"
|
||||||
|
<< std::string(reinterpret_cast<const char *>(data + 2), size)
|
||||||
|
<< "'";
|
||||||
|
if (!output_stream_.Write(data + 2, size)) {
|
||||||
|
throw utils::BasicException("Output stream write failed!");
|
||||||
|
}
|
||||||
|
input_stream_.Shift(size + 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
EchoData &data_;
|
||||||
|
communication::InputStream &input_stream_;
|
||||||
|
communication::OutputStream &output_stream_;
|
||||||
|
};
|
||||||
|
|
||||||
|
int main(int argc, char **argv) {
|
||||||
|
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||||
|
google::InitGoogleLogging(argv[0]);
|
||||||
|
|
||||||
|
communication::Init();
|
||||||
|
|
||||||
|
// Initialize the server.
|
||||||
|
EchoData echo_data;
|
||||||
|
io::network::Endpoint endpoint(FLAGS_address, FLAGS_port);
|
||||||
|
communication::ServerContext context(FLAGS_key_file, FLAGS_cert_file,
|
||||||
|
FLAGS_ca_file, FLAGS_verify_peer);
|
||||||
|
communication::Server<EchoSession, EchoData> server(endpoint, echo_data,
|
||||||
|
&context, -1, "SSL", 1);
|
||||||
|
|
||||||
|
while (echo_data.alive) {
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
1
tests/stress/.gitignore
vendored
1
tests/stress/.gitignore
vendored
@ -1 +1,2 @@
|
|||||||
.long_running_stats
|
.long_running_stats
|
||||||
|
*.pem
|
||||||
|
@ -10,6 +10,10 @@
|
|||||||
commands: TIMEOUT=600 ./continuous_integration --properties-on-disk
|
commands: TIMEOUT=600 ./continuous_integration --properties-on-disk
|
||||||
infiles: *STRESS_INFILES
|
infiles: *STRESS_INFILES
|
||||||
|
|
||||||
|
- name: stress_ssl
|
||||||
|
commands: TIMEOUT=600 ./continuous_integration --use-ssl
|
||||||
|
infiles: *STRESS_INFILES
|
||||||
|
|
||||||
- name: stress_large
|
- name: stress_large
|
||||||
project: release
|
project: release
|
||||||
commands: TIMEOUT=43200 ./continuous_integration --large-dataset
|
commands: TIMEOUT=43200 ./continuous_integration --large-dataset
|
||||||
|
@ -146,14 +146,14 @@ def connection_argument_parser():
|
|||||||
'''
|
'''
|
||||||
parser = ArgumentParser(description=__doc__)
|
parser = ArgumentParser(description=__doc__)
|
||||||
|
|
||||||
parser.add_argument('--endpoint', type=str, default='localhost:7687',
|
parser.add_argument('--endpoint', type=str, default='127.0.0.1:7687',
|
||||||
help='DBMS instance endpoint. '
|
help='DBMS instance endpoint. '
|
||||||
'Bolt protocol is the only option.')
|
'Bolt protocol is the only option.')
|
||||||
parser.add_argument('--username', type=str, default='neo4j',
|
parser.add_argument('--username', type=str, default='neo4j',
|
||||||
help='DBMS instance username.')
|
help='DBMS instance username.')
|
||||||
parser.add_argument('--password', type=int, default='1234',
|
parser.add_argument('--password', type=int, default='1234',
|
||||||
help='DBMS instance password.')
|
help='DBMS instance password.')
|
||||||
parser.add_argument('--ssl-enabled', action='store_true',
|
parser.add_argument('--use-ssl', action='store_true',
|
||||||
help="Is SSL enabled?")
|
help="Is SSL enabled?")
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
@ -163,7 +163,7 @@ def bolt_session(url, auth, ssl=False):
|
|||||||
'''
|
'''
|
||||||
with wrapper around Bolt session.
|
with wrapper around Bolt session.
|
||||||
|
|
||||||
:param url: str, e.g. "bolt://localhost:7687"
|
:param url: str, e.g. "bolt://127.0.0.1:7687"
|
||||||
:param auth: auth method, goes directly to the Bolt driver constructor
|
:param auth: auth method, goes directly to the Bolt driver constructor
|
||||||
:param ssl: bool, is ssl enabled
|
:param ssl: bool, is ssl enabled
|
||||||
'''
|
'''
|
||||||
@ -183,14 +183,15 @@ def argument_session(args):
|
|||||||
:return: Bolt session context manager based on program arguments
|
:return: Bolt session context manager based on program arguments
|
||||||
'''
|
'''
|
||||||
return bolt_session('bolt://' + args.endpoint,
|
return bolt_session('bolt://' + args.endpoint,
|
||||||
(args.username, str(args.password)))
|
(args.username, str(args.password)),
|
||||||
|
args.use_ssl)
|
||||||
|
|
||||||
|
|
||||||
def argument_driver(args, ssl=False):
|
def argument_driver(args):
|
||||||
return GraphDatabase.driver(
|
return GraphDatabase.driver(
|
||||||
'bolt://' + args.endpoint,
|
'bolt://' + args.endpoint,
|
||||||
auth=(args.username, str(args.password)),
|
auth=(args.username, str(args.password)),
|
||||||
encrypted=ssl)
|
encrypted=args.use_ssl)
|
||||||
|
|
||||||
# This class is used to create and cache sessions. Session is cached by args
|
# This class is used to create and cache sessions. Session is cached by args
|
||||||
# used to create it and process' pid in which it was created. This makes it easy
|
# used to create it and process' pid in which it was created. This makes it easy
|
||||||
|
@ -73,6 +73,8 @@ BASE_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", ".."))
|
|||||||
BUILD_DIR = os.path.join(BASE_DIR, "build")
|
BUILD_DIR = os.path.join(BASE_DIR, "build")
|
||||||
CONFIG_DIR = os.path.join(BASE_DIR, "config")
|
CONFIG_DIR = os.path.join(BASE_DIR, "config")
|
||||||
MEASUREMENTS_FILE = os.path.join(SCRIPT_DIR, ".apollo_measurements")
|
MEASUREMENTS_FILE = os.path.join(SCRIPT_DIR, ".apollo_measurements")
|
||||||
|
KEY_FILE = os.path.join(SCRIPT_DIR, ".key.pem")
|
||||||
|
CERT_FILE = os.path.join(SCRIPT_DIR, ".cert.pem")
|
||||||
|
|
||||||
# long running stats file
|
# long running stats file
|
||||||
STATS_FILE = os.path.join(SCRIPT_DIR, ".long_running_stats")
|
STATS_FILE = os.path.join(SCRIPT_DIR, ".long_running_stats")
|
||||||
@ -132,6 +134,8 @@ parser.add_argument("--python", default = os.path.join(SCRIPT_DIR,
|
|||||||
"ve3", "bin", "python3"), type = str)
|
"ve3", "bin", "python3"), type = str)
|
||||||
parser.add_argument("--large-dataset", action = "store_const",
|
parser.add_argument("--large-dataset", action = "store_const",
|
||||||
const = True, default = False)
|
const = True, default = False)
|
||||||
|
parser.add_argument("--use-ssl", action = "store_const",
|
||||||
|
const = True, default = False)
|
||||||
parser.add_argument("--verbose", action = "store_const",
|
parser.add_argument("--verbose", action = "store_const",
|
||||||
const = True, default = False)
|
const = True, default = False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@ -140,6 +144,14 @@ args = parser.parse_args()
|
|||||||
if not os.path.exists(args.memgraph):
|
if not os.path.exists(args.memgraph):
|
||||||
args.memgraph = os.path.join(BASE_DIR, "build_release", "memgraph")
|
args.memgraph = os.path.join(BASE_DIR, "build_release", "memgraph")
|
||||||
|
|
||||||
|
# generate temporary SSL certs
|
||||||
|
if args.use_ssl:
|
||||||
|
# https://unix.stackexchange.com/questions/104171/create-ssl-certificate-non-interactively
|
||||||
|
subj = "/C=HR/ST=Zagreb/L=Zagreb/O=Memgraph/CN=db.memgraph.com"
|
||||||
|
subprocess.run(["openssl", "req", "-new", "-newkey", "rsa:4096",
|
||||||
|
"-days", "365", "-nodes", "-x509", "-subj", subj,
|
||||||
|
"-keyout", KEY_FILE, "-out", CERT_FILE], check=True)
|
||||||
|
|
||||||
# start memgraph
|
# start memgraph
|
||||||
cwd = os.path.dirname(args.memgraph)
|
cwd = os.path.dirname(args.memgraph)
|
||||||
cmd = [args.memgraph, "--num-workers=" + str(THREADS)]
|
cmd = [args.memgraph, "--num-workers=" + str(THREADS)]
|
||||||
@ -151,6 +163,8 @@ if args.durability_directory:
|
|||||||
cmd += ["--durability-directory", args.durability_directory]
|
cmd += ["--durability-directory", args.durability_directory]
|
||||||
if args.properties_on_disk:
|
if args.properties_on_disk:
|
||||||
cmd += ["--properties-on-disk", "id,x"]
|
cmd += ["--properties-on-disk", "id,x"]
|
||||||
|
if args.use_ssl:
|
||||||
|
cmd += ["--cert-file", CERT_FILE, "--key-file", KEY_FILE]
|
||||||
proc_mg = subprocess.Popen(cmd, cwd = cwd,
|
proc_mg = subprocess.Popen(cmd, cwd = cwd,
|
||||||
env = {"MEMGRAPH_CONFIG": args.config})
|
env = {"MEMGRAPH_CONFIG": args.config})
|
||||||
time.sleep(1.0)
|
time.sleep(1.0)
|
||||||
@ -167,6 +181,8 @@ def cleanup():
|
|||||||
runtimes = {}
|
runtimes = {}
|
||||||
dataset = LARGE_DATASET if args.large_dataset else SMALL_DATASET
|
dataset = LARGE_DATASET if args.large_dataset else SMALL_DATASET
|
||||||
for test in dataset:
|
for test in dataset:
|
||||||
|
if args.use_ssl:
|
||||||
|
test["options"] += ["--use-ssl"]
|
||||||
runtime = run_test(args, **test)
|
runtime = run_test(args, **test)
|
||||||
runtimes[os.path.splitext(test["test"])[0]] = runtime
|
runtimes[os.path.splitext(test["test"])[0]] = runtime
|
||||||
|
|
||||||
@ -176,6 +192,11 @@ ret_mg = proc_mg.wait()
|
|||||||
if ret_mg != 0:
|
if ret_mg != 0:
|
||||||
raise Exception("Memgraph binary returned non-zero ({})!".format(ret_mg))
|
raise Exception("Memgraph binary returned non-zero ({})!".format(ret_mg))
|
||||||
|
|
||||||
|
# cleanup certificates
|
||||||
|
if args.use_ssl:
|
||||||
|
os.remove(KEY_FILE)
|
||||||
|
os.remove(CERT_FILE)
|
||||||
|
|
||||||
# measurements
|
# measurements
|
||||||
measurements = ""
|
measurements = ""
|
||||||
for key, value in runtimes.items():
|
for key, value in runtimes.items():
|
||||||
|
@ -8,6 +8,7 @@
|
|||||||
#include "utils/timer.hpp"
|
#include "utils/timer.hpp"
|
||||||
|
|
||||||
using EndpointT = io::network::Endpoint;
|
using EndpointT = io::network::Endpoint;
|
||||||
|
using ClientContextT = communication::ClientContext;
|
||||||
using ClientT = communication::bolt::Client;
|
using ClientT = communication::bolt::Client;
|
||||||
using DecodedValueT = communication::bolt::DecodedValue;
|
using DecodedValueT = communication::bolt::DecodedValue;
|
||||||
using QueryDataT = communication::bolt::QueryData;
|
using QueryDataT = communication::bolt::QueryData;
|
||||||
@ -17,6 +18,7 @@ DEFINE_string(address, "127.0.0.1", "Server address");
|
|||||||
DEFINE_int32(port, 7687, "Server port");
|
DEFINE_int32(port, 7687, "Server port");
|
||||||
DEFINE_string(username, "", "Username for the database");
|
DEFINE_string(username, "", "Username for the database");
|
||||||
DEFINE_string(password, "", "Password for the database");
|
DEFINE_string(password, "", "Password for the database");
|
||||||
|
DEFINE_bool(use_ssl, false, "Set to true to connect with SSL to the server.");
|
||||||
|
|
||||||
DEFINE_int32(vertex_count, 0,
|
DEFINE_int32(vertex_count, 0,
|
||||||
"The average number of vertices in the graph per worker");
|
"The average number of vertices in the graph per worker");
|
||||||
@ -51,7 +53,7 @@ class GraphSession {
|
|||||||
}
|
}
|
||||||
|
|
||||||
EndpointT endpoint(FLAGS_address, FLAGS_port);
|
EndpointT endpoint(FLAGS_address, FLAGS_port);
|
||||||
client_ = std::make_unique<ClientT>();
|
client_ = std::make_unique<ClientT>(&context_);
|
||||||
|
|
||||||
if (!client_->Connect(endpoint, FLAGS_username, FLAGS_password)) {
|
if (!client_->Connect(endpoint, FLAGS_username, FLAGS_password)) {
|
||||||
throw utils::BasicException("Couldn't connect to server!");
|
throw utils::BasicException("Couldn't connect to server!");
|
||||||
@ -60,6 +62,7 @@ class GraphSession {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
uint64_t id_;
|
uint64_t id_;
|
||||||
|
ClientContextT context_{FLAGS_use_ssl};
|
||||||
std::unique_ptr<ClientT> client_;
|
std::unique_ptr<ClientT> client_;
|
||||||
|
|
||||||
std::set<uint64_t> vertices_;
|
std::set<uint64_t> vertices_;
|
||||||
@ -362,6 +365,8 @@ int main(int argc, char **argv) {
|
|||||||
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
gflags::ParseCommandLineFlags(&argc, &argv, true);
|
||||||
google::InitGoogleLogging(argv[0]);
|
google::InitGoogleLogging(argv[0]);
|
||||||
|
|
||||||
|
communication::Init();
|
||||||
|
|
||||||
CHECK(FLAGS_vertex_count > 0) << "Vertex count must be greater than 0!";
|
CHECK(FLAGS_vertex_count > 0) << "Vertex count must be greater than 0!";
|
||||||
CHECK(FLAGS_edge_count > 0) << "Edge count must be greater than 0!";
|
CHECK(FLAGS_edge_count > 0) << "Edge count must be greater than 0!";
|
||||||
|
|
||||||
@ -369,7 +374,8 @@ int main(int argc, char **argv) {
|
|||||||
|
|
||||||
// create client
|
// create client
|
||||||
EndpointT endpoint(FLAGS_address, FLAGS_port);
|
EndpointT endpoint(FLAGS_address, FLAGS_port);
|
||||||
ClientT client;
|
ClientContextT context(FLAGS_use_ssl);
|
||||||
|
ClientT client(&context);
|
||||||
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) {
|
if (!client.Connect(endpoint, FLAGS_username, FLAGS_password)) {
|
||||||
throw utils::BasicException("Couldn't connect to server!");
|
throw utils::BasicException("Couldn't connect to server!");
|
||||||
}
|
}
|
||||||
|
@ -52,8 +52,9 @@ bool QueryServer(io::network::Socket &socket) {
|
|||||||
TEST(NetworkTimeouts, InactiveSession) {
|
TEST(NetworkTimeouts, InactiveSession) {
|
||||||
// Instantiate the server and set the session timeout to 2 seconds.
|
// Instantiate the server and set the session timeout to 2 seconds.
|
||||||
TestData test_data;
|
TestData test_data;
|
||||||
|
communication::ServerContext context;
|
||||||
communication::Server<TestSession, TestData> server{
|
communication::Server<TestSession, TestData> server{
|
||||||
{"127.0.0.1", 0}, test_data, 2, "Test", 1};
|
{"127.0.0.1", 0}, test_data, &context, 2, "Test", 1};
|
||||||
|
|
||||||
// Create the client and connect to the server.
|
// Create the client and connect to the server.
|
||||||
io::network::Socket client;
|
io::network::Socket client;
|
||||||
|
94
tests/unit/socket.cpp
Normal file
94
tests/unit/socket.cpp
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
#include <chrono>
|
||||||
|
#include <csignal>
|
||||||
|
#include <thread>
|
||||||
|
|
||||||
|
#include <glog/logging.h>
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include "io/network/socket.hpp"
|
||||||
|
#include "utils/timer.hpp"
|
||||||
|
|
||||||
|
TEST(Socket, WaitForReadyRead) {
|
||||||
|
io::network::Socket server;
|
||||||
|
ASSERT_TRUE(server.Bind({"127.0.0.1", 0}));
|
||||||
|
ASSERT_TRUE(server.Listen(1024));
|
||||||
|
|
||||||
|
std::thread thread([&server] {
|
||||||
|
io::network::Socket client;
|
||||||
|
ASSERT_TRUE(client.Connect(server.endpoint()));
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(2000));
|
||||||
|
ASSERT_TRUE(client.Write("test"));
|
||||||
|
});
|
||||||
|
|
||||||
|
uint8_t buff[100];
|
||||||
|
auto client = server.Accept();
|
||||||
|
ASSERT_TRUE(client);
|
||||||
|
|
||||||
|
client->SetNonBlocking();
|
||||||
|
|
||||||
|
ASSERT_EQ(client->Read(buff, sizeof(buff)), -1);
|
||||||
|
ASSERT_TRUE(errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR);
|
||||||
|
|
||||||
|
utils::Timer timer;
|
||||||
|
ASSERT_TRUE(client->WaitForReadyRead());
|
||||||
|
ASSERT_GT(timer.Elapsed().count(), 1.0);
|
||||||
|
|
||||||
|
ASSERT_GT(client->Read(buff, sizeof(buff)), 0);
|
||||||
|
|
||||||
|
thread.join();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Socket, WaitForReadyWrite) {
|
||||||
|
io::network::Socket server;
|
||||||
|
ASSERT_TRUE(server.Bind({"127.0.0.1", 0}));
|
||||||
|
ASSERT_TRUE(server.Listen(1024));
|
||||||
|
|
||||||
|
std::thread thread([&server] {
|
||||||
|
uint8_t buff[10000];
|
||||||
|
io::network::Socket client;
|
||||||
|
ASSERT_TRUE(client.Connect(server.endpoint()));
|
||||||
|
client.SetNonBlocking();
|
||||||
|
|
||||||
|
// Decrease the TCP read buffer.
|
||||||
|
int len = 1024;
|
||||||
|
ASSERT_EQ(setsockopt(client.fd(), SOL_SOCKET, SO_RCVBUF, &len, sizeof(len)),
|
||||||
|
0);
|
||||||
|
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(2000));
|
||||||
|
while (true) {
|
||||||
|
int ret = client.Read(buff, sizeof(buff));
|
||||||
|
if (ret == -1 && errno != EAGAIN && errno != EWOULDBLOCK &&
|
||||||
|
errno != EINTR) {
|
||||||
|
std::raise(SIGPIPE);
|
||||||
|
} else if (ret == 0) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
auto client = server.Accept();
|
||||||
|
ASSERT_TRUE(client);
|
||||||
|
|
||||||
|
client->SetNonBlocking();
|
||||||
|
|
||||||
|
// Decrease the TCP write buffer.
|
||||||
|
int len = 1024;
|
||||||
|
ASSERT_EQ(setsockopt(client->fd(), SOL_SOCKET, SO_SNDBUF, &len, sizeof(len)),
|
||||||
|
0);
|
||||||
|
|
||||||
|
utils::Timer timer;
|
||||||
|
for (int i = 0; i < 1000000; ++i) {
|
||||||
|
ASSERT_TRUE(client->Write("test"));
|
||||||
|
}
|
||||||
|
ASSERT_GT(timer.Elapsed().count(), 1.0);
|
||||||
|
|
||||||
|
client->Close();
|
||||||
|
|
||||||
|
thread.join();
|
||||||
|
}
|
||||||
|
|
||||||
|
int main(int argc, char **argv) {
|
||||||
|
google::InitGoogleLogging(argv[0]);
|
||||||
|
::testing::InitGoogleTest(&argc, argv);
|
||||||
|
return RUN_ALL_TESTS();
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user