Support Bolt v4.3 protocol ()

This commit is contained in:
Jure Bajic 2021-10-28 13:05:09 +02:00 committed by GitHub
parent e40bc32624
commit a9b1ff9bea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 748 additions and 428 deletions

View File

@ -31,6 +31,7 @@ enum class Signature : uint8_t {
Begin = 0x11,
Commit = 0x12,
Rollback = 0x13,
Route = 0x66,
Record = 0x71,
Success = 0x70,

View File

@ -28,7 +28,7 @@ static constexpr size_t kChunkWholeSize = kChunkHeaderSize + kChunkMaxDataSize;
*/
static constexpr size_t kHandshakeSize = 20;
static constexpr uint16_t kSupportedVersions[3] = {0x0100, 0x0400, 0x0401};
static constexpr uint16_t kSupportedVersions[] = {0x0100, 0x0400, 0x0401, 0x0403};
static constexpr int kPullAll = -1;
static constexpr int kPullLast = -1;

View File

@ -12,6 +12,7 @@
#pragma once
#include <map>
#include <memory>
#include <new>
#include <string>
@ -19,339 +20,69 @@
#include "communication/bolt/v1/constants.hpp"
#include "communication/bolt/v1/exceptions.hpp"
#include "communication/bolt/v1/state.hpp"
#include "communication/bolt/v1/states/handlers.hpp"
#include "communication/bolt/v1/value.hpp"
#include "communication/exceptions.hpp"
#include "utils/likely.hpp"
#include "utils/logging.hpp"
#include "utils/message.hpp"
namespace communication::bolt {
// TODO (mferencevic): revise these error messages
inline std::pair<std::string, std::string> ExceptionToErrorMessage(const std::exception &e) {
if (auto *verbose = dynamic_cast<const VerboseError *>(&e)) {
return {verbose->code(), verbose->what()};
}
if (dynamic_cast<const ClientError *>(&e)) {
// Clients expect 4 strings separated by dots. First being database name
// (for example: Neo, Memgraph...), second being either ClientError,
// TransientError or DatabaseError (or ClientNotification for warnings).
// ClientError means wrong query, do not retry. DatabaseError means
// something wrong in database, do not retry. TransientError means query
// failed, but if retried it may succeed, retry it.
//
// Third and fourth strings being namespace and specific error name.
// It is not really important what we put there since we don't expect
// any special handling of specific exceptions on client side, but we
// need to make sure that we don't accidentally return some exception
// name which clients handle in a special way. For example, if client
// receives *.TransientError.Transaction.Terminate it will not rerun
// query even though TransientError was returned, because of Neo's
// semantics of that error.
return {"Memgraph.ClientError.MemgraphError.MemgraphError", e.what()};
}
if (dynamic_cast<const utils::BasicException *>(&e)) {
// Exception not derived from QueryException was thrown which means that
// database probably aborted transaction because of some timeout,
// deadlock, serialization error or something similar. We return
// TransientError since retry of same transaction could succeed.
return {"Memgraph.TransientError.MemgraphError.MemgraphError", e.what()};
}
if (dynamic_cast<const std::bad_alloc *>(&e)) {
// std::bad_alloc was thrown, God knows in which state is database ->
// terminate.
LOG_FATAL("Memgraph is out of memory");
}
// All exceptions used in memgraph are derived from BasicException. Since
// we caught some other exception we don't know what is going on. Return
// DatabaseError, log real message and return generic string.
spdlog::error(
utils::MessageWithLink("Unknown exception occurred during query execution {}.", e.what(), "https://memgr.ph/unknown"));
return {"Memgraph.DatabaseError.MemgraphError.MemgraphError",
"An unknown exception occurred, this is unexpected. Real message "
"should be in database logs."};
}
template <typename TSession>
inline State HandleFailure(TSession &session, const std::exception &e) {
spdlog::trace("Error message: {}", e.what());
if (const auto *p = dynamic_cast<const utils::StacktraceException *>(&e)) {
spdlog::trace("Error trace: {}", p->trace());
}
session.encoder_buffer_.Clear();
auto code_message = ExceptionToErrorMessage(e);
bool fail_sent = session.encoder_.MessageFailure({{"code", code_message.first}, {"message", code_message.second}});
if (!fail_sent) {
spdlog::trace("Couldn't send failure message!");
return State::Close;
}
return State::Error;
}
template <typename TSession>
State HandleRun(TSession &session, State state, Marker marker) {
const std::map<std::string, Value> kEmptyFields = {{"fields", std::vector<Value>{}}};
const auto expected_marker = session.version_.major == 1 ? Marker::TinyStruct2 : Marker::TinyStruct3;
if (marker != expected_marker) {
spdlog::trace("Expected {} marker, but received 0x{:02X}!",
session.version_.major == 1 ? "TinyStruct2" : "TinyStruct3", utils::UnderlyingCast(marker));
return State::Close;
}
Value query, params, extra;
if (!session.decoder_.ReadValue(&query, Value::Type::String)) {
spdlog::trace("Couldn't read query string!");
return State::Close;
}
if (!session.decoder_.ReadValue(&params, Value::Type::Map)) {
spdlog::trace("Couldn't read parameters!");
return State::Close;
}
if (session.version_.major == 4) {
if (!session.decoder_.ReadValue(&extra, Value::Type::Map)) {
spdlog::trace("Couldn't read extra field!");
}
}
if (state != State::Idle) {
// Client could potentially recover if we move to error state, but there is
// no legitimate situation in which well working client would end up in this
// situation.
spdlog::trace("Unexpected RUN command!");
return State::Close;
}
DMG_ASSERT(!session.encoder_buffer_.HasData(), "There should be no data to write in this state");
spdlog::debug("[Run] '{}'", query.ValueString());
try {
// Interpret can throw.
auto [header, qid] = session.Interpret(query.ValueString(), params.ValueMap());
// Convert std::string to Value
std::vector<Value> vec;
std::map<std::string, Value> data;
vec.reserve(header.size());
for (auto &i : header) vec.emplace_back(std::move(i));
data.emplace("fields", std::move(vec));
// Send the header.
if (!session.encoder_.MessageSuccess(data)) {
spdlog::trace("Couldn't send query header!");
State RunHandlerV1(Signature signature, TSession &session, State state, Marker marker) {
switch (signature) {
case Signature::Run:
return HandleRunV1<TSession>(session, state, marker);
case Signature::Pull:
return HandlePullV1<TSession>(session, state, marker);
case Signature::Discard:
return HandleDiscardV1<TSession>(session, state, marker);
case Signature::Reset:
return HandleReset<TSession>(session, marker);
default:
spdlog::trace("Unrecognized signature received (0x{:02X})!", utils::UnderlyingCast(signature));
return State::Close;
}
return State::Result;
} catch (const std::exception &e) {
return HandleFailure(session, e);
}
}
namespace detail {
template <bool is_pull, typename TSession>
State HandlePullDiscard(TSession &session, State state, Marker marker) {
const auto expected_marker = session.version_.major == 1 ? Marker::TinyStruct : Marker::TinyStruct1;
if (marker != expected_marker) {
spdlog::trace("Expected {} marker, but received 0x{:02X}!",
session.version_.major == 1 ? "TinyStruct" : "TinyStruct1", utils::UnderlyingCast(marker));
return State::Close;
}
if (state != State::Result) {
if constexpr (is_pull) {
spdlog::trace("Unexpected PULL!");
} else {
spdlog::trace("Unexpected DISCARD!");
}
// Same as `unexpected RUN` case.
return State::Close;
}
try {
std::optional<int> n;
std::optional<int> qid;
if (session.version_.major == 4) {
Value extra;
if (!session.decoder_.ReadValue(&extra, Value::Type::Map)) {
spdlog::trace("Couldn't read extra field!");
}
const auto &extra_map = extra.ValueMap();
if (extra_map.count("n")) {
if (const auto n_value = extra_map.at("n").ValueInt(); n_value != kPullAll) {
n = n_value;
}
}
if (extra_map.count("qid")) {
if (const auto qid_value = extra_map.at("qid").ValueInt(); qid_value != kPullLast) {
qid = qid_value;
}
template <typename TSession, int bolt_minor = 0>
State RunHandlerV4(Signature signature, TSession &session, State state, Marker marker) {
switch (signature) {
case Signature::Run:
return HandleRunV4<TSession>(session, state, marker);
case Signature::Pull:
return HandlePullV4<TSession>(session, state, marker);
case Signature::Discard:
return HandleDiscardV4<TSession>(session, state, marker);
case Signature::Reset:
return HandleReset<TSession>(session, marker);
case Signature::Begin:
return HandleBegin<TSession>(session, state, marker);
case Signature::Commit:
return HandleCommit<TSession>(session, state, marker);
case Signature::Goodbye:
return HandleGoodbye<TSession>();
case Signature::Rollback:
return HandleRollback<TSession>(session, state, marker);
case Signature::Noop: {
if constexpr (bolt_minor >= 1) {
return HandleNoop<TSession>(state);
} else {
spdlog::trace("Supported only in bolt v4.1");
return State::Close;
}
}
std::map<std::string, Value> summary;
if constexpr (is_pull) {
// Pull can throw.
summary = session.Pull(&session.encoder_, n, qid);
} else {
summary = session.Discard(n, qid);
case Signature::Route: {
if constexpr (bolt_minor >= 3) {
if (signature == Signature::Route) return HandleRoute<TSession>(session);
} else {
spdlog::trace("Supported only in bolt v4.3");
return State::Close;
}
}
if (!session.encoder_.MessageSuccess(summary)) {
spdlog::trace("Couldn't send query summary!");
default:
spdlog::trace("Unrecognized signature received (0x{:02X})!", utils::UnderlyingCast(signature));
return State::Close;
}
if (summary.count("has_more") && summary.at("has_more").ValueBool()) {
return State::Result;
}
return State::Idle;
} catch (const std::exception &e) {
return HandleFailure(session, e);
}
}
} // namespace detail
template <typename Session>
State HandlePull(Session &session, State state, Marker marker) {
return detail::HandlePullDiscard<true>(session, state, marker);
}
template <typename Session>
State HandleDiscard(Session &session, State state, Marker marker) {
return detail::HandlePullDiscard<false>(session, state, marker);
}
template <typename Session>
State HandleReset(Session &session, State, Marker marker) {
// IMPORTANT: This implementation of the Bolt RESET command isn't fully
// compliant to the protocol definition. In the protocol it is defined
// that this command should immediately stop any running commands and
// reset the session to a clean state. That means that we should always
// make a look-ahead for the RESET command before processing anything.
// Our implementation, for now, does everything in a blocking fashion
// so we cannot simply "kill" a transaction while it is running. So
// now this command only resets the session to a clean state. It
// does not IGNORE running and pending commands as it should.
if (marker != Marker::TinyStruct) {
spdlog::trace("Expected TinyStruct marker, but received 0x{:02X}!", utils::UnderlyingCast(marker));
return State::Close;
}
// Clear all pending data and send a success message.
session.encoder_buffer_.Clear();
if (!session.encoder_.MessageSuccess()) {
spdlog::trace("Couldn't send success message!");
return State::Close;
}
session.Abort();
return State::Idle;
}
template <typename Session>
State HandleBegin(Session &session, State state, Marker marker) {
if (session.version_.major == 1) {
spdlog::trace("BEGIN messsage not supported in Bolt v1!");
return State::Close;
}
if (marker != Marker::TinyStruct1) {
spdlog::trace("Expected TinyStruct1 marker, but received 0x{:02x}!", utils::UnderlyingCast(marker));
return State::Close;
}
Value extra;
if (!session.decoder_.ReadValue(&extra, Value::Type::Map)) {
spdlog::trace("Couldn't read extra fields!");
return State::Close;
}
if (state != State::Idle) {
spdlog::trace("Unexpected BEGIN command!");
return State::Close;
}
DMG_ASSERT(!session.encoder_buffer_.HasData(), "There should be no data to write in this state");
if (!session.encoder_.MessageSuccess({})) {
spdlog::trace("Couldn't send success message!");
return State::Close;
}
try {
session.BeginTransaction();
} catch (const std::exception &e) {
return HandleFailure(session, e);
}
return State::Idle;
}
template <typename Session>
State HandleCommit(Session &session, State state, Marker marker) {
if (session.version_.major == 1) {
spdlog::trace("COMMIT messsage not supported in Bolt v1!");
return State::Close;
}
if (marker != Marker::TinyStruct) {
spdlog::trace("Expected TinyStruct marker, but received 0x{:02x}!", utils::UnderlyingCast(marker));
return State::Close;
}
if (state != State::Idle) {
spdlog::trace("Unexpected COMMIT command!");
return State::Close;
}
DMG_ASSERT(!session.encoder_buffer_.HasData(), "There should be no data to write in this state");
try {
if (!session.encoder_.MessageSuccess({})) {
spdlog::trace("Couldn't send success message!");
return State::Close;
}
session.CommitTransaction();
return State::Idle;
} catch (const std::exception &e) {
return HandleFailure(session, e);
}
}
template <typename Session>
State HandleRollback(Session &session, State state, Marker marker) {
if (session.version_.major == 1) {
spdlog::trace("ROLLBACK messsage not supported in Bolt v1!");
return State::Close;
}
if (marker != Marker::TinyStruct) {
spdlog::trace("Expected TinyStruct marker, but received 0x{:02x}!", utils::UnderlyingCast(marker));
return State::Close;
}
if (state != State::Idle) {
spdlog::trace("Unexpected ROLLBACK command!");
return State::Close;
}
DMG_ASSERT(!session.encoder_buffer_.HasData(), "There should be no data to write in this state");
try {
if (!session.encoder_.MessageSuccess({})) {
spdlog::trace("Couldn't send success message!");
return State::Close;
}
session.RollbackTransaction();
return State::Idle;
} catch (const std::exception &e) {
return HandleFailure(session, e);
}
}
@ -361,8 +92,8 @@ State HandleRollback(Session &session, State state, Marker marker) {
* It executes: RUN, PULL_ALL, DISCARD_ALL & RESET.
* @param session the session that should be used for the run
*/
template <typename Session>
State StateExecutingRun(Session &session, State state) {
template <typename TSession>
State StateExecutingRun(TSession &session, State state) {
Marker marker;
Signature signature;
if (!session.decoder_.ReadMessageHeader(&signature, &marker)) {
@ -370,30 +101,21 @@ State StateExecutingRun(Session &session, State state) {
return State::Close;
}
if (UNLIKELY(signature == Signature::Noop && session.version_.major == 4 && session.version_.minor == 1)) {
spdlog::trace("Received NOOP message");
return state;
}
if (signature == Signature::Run) {
return HandleRun(session, state, marker);
} else if (signature == Signature::Pull) {
return HandlePull(session, state, marker);
} else if (signature == Signature::Discard) {
return HandleDiscard(session, state, marker);
} else if (signature == Signature::Begin) {
return HandleBegin(session, state, marker);
} else if (signature == Signature::Commit) {
return HandleCommit(session, state, marker);
} else if (signature == Signature::Rollback) {
return HandleRollback(session, state, marker);
} else if (signature == Signature::Reset) {
return HandleReset(session, state, marker);
} else if (signature == Signature::Goodbye && session.version_.major != 1) {
throw SessionClosedException("Closing connection.");
} else {
spdlog::trace("Unrecognized signature received (0x{:02X})!", utils::UnderlyingCast(signature));
return State::Close;
switch (session.version_.major) {
case 1:
return RunHandlerV1(signature, session, state, marker);
case 4: {
if (session.version_.minor >= 3) {
return RunHandlerV4<TSession, 3>(signature, session, state, marker);
}
if (session.version_.minor >= 1) {
return RunHandlerV4<TSession, 1>(signature, session, state, marker);
}
return RunHandlerV4<TSession>(signature, session, state, marker);
}
default:
spdlog::trace("Unsupported bolt version:{}.{})!", session.version_.major, session.version_.minor);
return State::Close;
}
}
} // namespace communication::bolt

View File

@ -0,0 +1,418 @@
// Copyright 2021 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
// License, and you may not use this file except in compliance with the Business Source License.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
#pragma once
#include <map>
#include <string>
#include <string_view>
#include <vector>
#include "communication/bolt/v1/codes.hpp"
#include "communication/bolt/v1/constants.hpp"
#include "communication/bolt/v1/state.hpp"
#include "communication/bolt/v1/value.hpp"
#include "communication/exceptions.hpp"
#include "utils/logging.hpp"
#include "utils/message.hpp"
namespace communication::bolt {
// TODO: Revise these error messages
inline std::pair<std::string, std::string> ExceptionToErrorMessage(const std::exception &e) {
if (const auto *verbose = dynamic_cast<const VerboseError *>(&e)) {
return {verbose->code(), verbose->what()};
}
if (dynamic_cast<const ClientError *>(&e)) {
// Clients expect 4 strings separated by dots. First being database name
// (for example: Neo, Memgraph...), second being either ClientError,
// TransientError or DatabaseError (or ClientNotification for warnings).
// ClientError means wrong query, do not retry. DatabaseError means
// something wrong in database, do not retry. TransientError means query
// failed, but if retried it may succeed, retry it.
//
// Third and fourth strings being namespace and specific error name.
// It is not really important what we put there since we don't expect
// any special handling of specific exceptions on client side, but we
// need to make sure that we don't accidentally return some exception
// name which clients handle in a special way. For example, if client
// receives *.TransientError.Transaction.Terminate it will not rerun
// query even though TransientError was returned, because of Neo's
// semantics of that error.
return {"Memgraph.ClientError.MemgraphError.MemgraphError", e.what()};
}
if (dynamic_cast<const utils::BasicException *>(&e)) {
// Exception not derived from QueryException was thrown which means that
// database probably aborted transaction because of some timeout,
// deadlock, serialization error or something similar. We return
// TransientError since retry of same transaction could succeed.
return {"Memgraph.TransientError.MemgraphError.MemgraphError", e.what()};
}
if (dynamic_cast<const std::bad_alloc *>(&e)) {
// std::bad_alloc was thrown, God knows in which state is database ->
// terminate.
LOG_FATAL("Memgraph is out of memory");
}
// All exceptions used in memgraph are derived from BasicException. Since
// we caught some other exception we don't know what is going on. Return
// DatabaseError, log real message and return generic string.
spdlog::error(utils::MessageWithLink("Unknown exception occurred during query execution {}.", e.what(),
"https://memgr.ph/unknown"));
return {"Memgraph.DatabaseError.MemgraphError.MemgraphError",
"An unknown exception occurred, this is unexpected. Real message "
"should be in database logs."};
}
namespace details {
template <typename TSession>
State HandleRun(TSession &session, const State state, const Value &query, const Value &params) {
if (state != State::Idle) {
// Client could potentially recover if we move to error state, but there is
// no legitimate situation in which well working client would end up in this
// situation.
spdlog::trace("Unexpected RUN command!");
return State::Close;
}
DMG_ASSERT(!session.encoder_buffer_.HasData(), "There should be no data to write in this state");
spdlog::debug("[Run] '{}'", query.ValueString());
try {
// Interpret can throw.
const auto [header, qid] = session.Interpret(query.ValueString(), params.ValueMap());
// Convert std::string to Value
std::vector<Value> vec;
std::map<std::string, Value> data;
vec.reserve(header.size());
for (auto &i : header) vec.emplace_back(std::move(i));
data.emplace("fields", std::move(vec));
// Send the header.
if (!session.encoder_.MessageSuccess(data)) {
spdlog::trace("Couldn't send query header!");
return State::Close;
}
return State::Result;
} catch (const std::exception &e) {
return HandleFailure(session, e);
}
}
template <bool is_pull, typename TSession>
State HandlePullDiscard(TSession &session, std::optional<int> n, std::optional<int> qid) {
try {
std::map<std::string, Value> summary;
if constexpr (is_pull) {
// Pull can throw.
summary = session.Pull(&session.encoder_, n, qid);
} else {
summary = session.Discard(n, qid);
}
if (!session.encoder_.MessageSuccess(summary)) {
spdlog::trace("Couldn't send query summary!");
return State::Close;
}
if (summary.count("has_more") && summary.at("has_more").ValueBool()) {
return State::Result;
}
return State::Idle;
} catch (const std::exception &e) {
return HandleFailure(session, e);
}
}
template <bool is_pull, typename TSession>
State HandlePullDiscardV1(TSession &session, const State state, const Marker marker) {
const auto expected_marker = Marker::TinyStruct;
if (marker != expected_marker) {
spdlog::trace("Expected {} marker, but received 0x{:02X}!", "TinyStruct", utils::UnderlyingCast(marker));
return State::Close;
}
if (state != State::Result) {
if constexpr (is_pull) {
spdlog::trace("Unexpected PULL!");
} else {
spdlog::trace("Unexpected DISCARD!");
}
// Same as `unexpected RUN` case.
return State::Close;
}
return HandlePullDiscard<is_pull, TSession>(session, std::nullopt, std::nullopt);
}
template <bool is_pull, typename TSession>
State HandlePullDiscardV4(TSession &session, const State state, const Marker marker) {
const auto expected_marker = Marker::TinyStruct1;
if (marker != expected_marker) {
spdlog::trace("Expected {} marker, but received 0x{:02X}!", "TinyStruct1", utils::UnderlyingCast(marker));
return State::Close;
}
if (state != State::Result) {
if constexpr (is_pull) {
spdlog::trace("Unexpected PULL!");
} else {
spdlog::trace("Unexpected DISCARD!");
}
// Same as `unexpected RUN` case.
return State::Close;
}
std::optional<int> n;
std::optional<int> qid;
Value extra;
if (!session.decoder_.ReadValue(&extra, Value::Type::Map)) {
spdlog::trace("Couldn't read extra field!");
}
const auto &extra_map = extra.ValueMap();
if (extra_map.count("n")) {
if (const auto n_value = extra_map.at("n").ValueInt(); n_value != kPullAll) {
n = n_value;
}
}
if (extra_map.count("qid")) {
if (const auto qid_value = extra_map.at("qid").ValueInt(); qid_value != kPullLast) {
qid = qid_value;
}
}
return HandlePullDiscard<is_pull, TSession>(session, n, qid);
}
} // namespace details
template <typename TSession>
inline State HandleFailure(TSession &session, const std::exception &e) {
spdlog::trace("Error message: {}", e.what());
if (const auto *p = dynamic_cast<const utils::StacktraceException *>(&e)) {
spdlog::trace("Error trace: {}", p->trace());
}
session.encoder_buffer_.Clear();
auto code_message = ExceptionToErrorMessage(e);
bool fail_sent = session.encoder_.MessageFailure({{"code", code_message.first}, {"message", code_message.second}});
if (!fail_sent) {
spdlog::trace("Couldn't send failure message!");
return State::Close;
}
return State::Error;
}
template <typename TSession>
State HandleRunV1(TSession &session, const State state, const Marker marker) {
const auto expected_marker = Marker::TinyStruct2;
if (marker != expected_marker) {
spdlog::trace("Expected {} marker, but received 0x{:02X}!",
session.version_.major == 1 ? "TinyStruct2" : "TinyStruct3", utils::UnderlyingCast(marker));
return State::Close;
}
Value query, params;
if (!session.decoder_.ReadValue(&query, Value::Type::String)) {
spdlog::trace("Couldn't read query string!");
return State::Close;
}
if (!session.decoder_.ReadValue(&params, Value::Type::Map)) {
spdlog::trace("Couldn't read parameters!");
return State::Close;
}
return details::HandleRun(session, state, query, params);
}
template <typename TSession>
State HandleRunV4(TSession &session, const State state, const Marker marker) {
const auto expected_marker = Marker::TinyStruct3;
if (marker != expected_marker) {
spdlog::trace("Expected {} marker, but received 0x{:02X}!", "TinyStruct3", utils::UnderlyingCast(marker));
return State::Close;
}
Value query, params, extra;
if (!session.decoder_.ReadValue(&query, Value::Type::String)) {
spdlog::trace("Couldn't read query string!");
return State::Close;
}
if (!session.decoder_.ReadValue(&params, Value::Type::Map)) {
spdlog::trace("Couldn't read parameters!");
return State::Close;
}
// Even though this part seems unnecessary it is needed to move the buffer
if (!session.decoder_.ReadValue(&extra, Value::Type::Map)) {
spdlog::trace("Couldn't read extra field!");
}
return details::HandleRun(session, state, query, params);
}
template <typename TSession>
State HandlePullV1(TSession &session, const State state, const Marker marker) {
return details::HandlePullDiscardV1<true>(session, state, marker);
}
template <typename TSession>
State HandlePullV4(TSession &session, const State state, const Marker marker) {
return details::HandlePullDiscardV4<true>(session, state, marker);
}
template <typename TSession>
State HandleDiscardV1(TSession &session, const State state, const Marker marker) {
return details::HandlePullDiscardV1<false>(session, state, marker);
}
template <typename TSession>
State HandleDiscardV4(TSession &session, const State state, const Marker marker) {
return details::HandlePullDiscardV4<false>(session, state, marker);
}
template <typename TSession>
State HandleReset(TSession &session, const Marker marker) {
// IMPORTANT: This implementation of the Bolt RESET command isn't fully
// compliant to the protocol definition. In the protocol it is defined
// that this command should immediately stop any running commands and
// reset the session to a clean state. That means that we should always
// make a look-ahead for the RESET command before processing anything.
// Our implementation, for now, does everything in a blocking fashion
// so we cannot simply "kill" a transaction while it is running. So
// now this command only resets the session to a clean state. It
// does not IGNORE running and pending commands as it should.
if (marker != Marker::TinyStruct) {
spdlog::trace("Expected TinyStruct marker, but received 0x{:02X}!", utils::UnderlyingCast(marker));
return State::Close;
}
// Clear all pending data and send a success message.
session.encoder_buffer_.Clear();
if (!session.encoder_.MessageSuccess()) {
spdlog::trace("Couldn't send success message!");
return State::Close;
}
session.Abort();
return State::Idle;
}
template <typename TSession>
State HandleBegin(TSession &session, const State state, const Marker marker) {
if (marker != Marker::TinyStruct1) {
spdlog::trace("Expected TinyStruct1 marker, but received 0x{:02x}!", utils::UnderlyingCast(marker));
return State::Close;
}
Value extra;
if (!session.decoder_.ReadValue(&extra, Value::Type::Map)) {
spdlog::trace("Couldn't read extra fields!");
return State::Close;
}
if (state != State::Idle) {
spdlog::trace("Unexpected BEGIN command!");
return State::Close;
}
DMG_ASSERT(!session.encoder_buffer_.HasData(), "There should be no data to write in this state");
if (!session.encoder_.MessageSuccess({})) {
spdlog::trace("Couldn't send success message!");
return State::Close;
}
try {
session.BeginTransaction();
} catch (const std::exception &e) {
return HandleFailure(session, e);
}
return State::Idle;
}
template <typename TSession>
State HandleCommit(TSession &session, const State state, const Marker marker) {
if (marker != Marker::TinyStruct) {
spdlog::trace("Expected TinyStruct marker, but received 0x{:02x}!", utils::UnderlyingCast(marker));
return State::Close;
}
if (state != State::Idle) {
spdlog::trace("Unexpected COMMIT command!");
return State::Close;
}
DMG_ASSERT(!session.encoder_buffer_.HasData(), "There should be no data to write in this state");
try {
if (!session.encoder_.MessageSuccess({})) {
spdlog::trace("Couldn't send success message!");
return State::Close;
}
session.CommitTransaction();
return State::Idle;
} catch (const std::exception &e) {
return HandleFailure(session, e);
}
}
template <typename TSession>
State HandleRollback(TSession &session, const State state, const Marker marker) {
if (marker != Marker::TinyStruct) {
spdlog::trace("Expected TinyStruct marker, but received 0x{:02x}!", utils::UnderlyingCast(marker));
return State::Close;
}
if (state != State::Idle) {
spdlog::trace("Unexpected ROLLBACK command!");
return State::Close;
}
DMG_ASSERT(!session.encoder_buffer_.HasData(), "There should be no data to write in this state");
try {
if (!session.encoder_.MessageSuccess({})) {
spdlog::trace("Couldn't send success message!");
return State::Close;
}
session.RollbackTransaction();
return State::Idle;
} catch (const std::exception &e) {
return HandleFailure(session, e);
}
}
template <typename TSession>
State HandleNoop(const State state) {
spdlog::trace("Received NOOP message");
return state;
}
template <typename TSession>
State HandleGoodbye() {
throw SessionClosedException("Closing connection.");
}
template <typename TSession>
State HandleRoute(TSession &session) {
// Route message is not implemented since it is neo4j specific, therefore we
// will receive it an inform user that there is no implementation.
session.encoder_buffer_.Clear();
bool fail_sent =
session.encoder_.MessageFailure({{"code", 66}, {"message", "Route message not supported in Memgraph!"}});
if (!fail_sent) {
spdlog::trace("Couldn't send failure message!");
return State::Close;
}
return State::Error;
}
} // namespace communication::bolt

View File

@ -12,6 +12,9 @@
#pragma once
#include <fmt/format.h>
#include <algorithm>
#include <cstdint>
#include <iterator>
#include "communication/bolt/v1/codes.hpp"
#include "communication/bolt/v1/constants.hpp"
@ -21,6 +24,35 @@
namespace communication::bolt {
inline bool CopyProtocolInformationIfSupported(uint16_t version, uint8_t *protocol) {
const auto *supported_version = std::find(std::begin(kSupportedVersions), std::end(kSupportedVersions), version);
if (supported_version != std::end(kSupportedVersions)) {
std::memcpy(protocol, &version, sizeof(version));
return true;
}
return false;
}
inline bool CopyProtocolInformationIfSupportedWithOffset(auto data_position, uint8_t *protocol) {
struct bolt_range_version {
uint8_t offset;
uint8_t minor;
uint8_t major;
} bolt_range_version;
std::memcpy(&bolt_range_version, data_position, sizeof(bolt_range_version));
if (bolt_range_version.major == 0 || bolt_range_version.minor == 0) return false;
bolt_range_version.offset = std::min(bolt_range_version.offset, bolt_range_version.minor);
for (uint8_t i{0U}; i <= bolt_range_version.offset; i++) {
uint8_t current_minor = bolt_range_version.minor - i;
if (CopyProtocolInformationIfSupported(static_cast<uint16_t>((bolt_range_version.major << 8U) + current_minor),
protocol)) {
return true;
}
}
return false;
}
/**
* Handshake state run function
* This function runs everything to make a Bolt handshake with the client.
@ -29,7 +61,7 @@ namespace communication::bolt {
template <typename TSession>
State StateHandshakeRun(TSession &session) {
auto precmp = std::memcmp(session.input_stream_.data(), kPreamble, sizeof(kPreamble));
if (UNLIKELY(precmp != 0)) {
if (precmp != 0) [[unlikely]] {
spdlog::trace("Received a wrong preamble!");
return State::Close;
}
@ -37,22 +69,22 @@ State StateHandshakeRun(TSession &session) {
DMG_ASSERT(session.input_stream_.size() >= kHandshakeSize, "Wrong size of the handshake data!");
auto dataPosition = session.input_stream_.data() + sizeof(kPreamble);
uint8_t protocol[4] = {0x00};
for (int i = 0; i < 4 && !protocol[3]; ++i) {
dataPosition += 2; // version is defined only by the last 2 bytes
uint16_t version = 0;
for (int i = 0; i < 4 && !protocol[3]; ++i) {
// If there is an offset defined (e.g. 0x00 0x03 0x03 0x04) the second byte
// That would enable the client to pick between 4.0 and 4.3 versions
// as per changes in handshake bolt protocol in v4.3
if (CopyProtocolInformationIfSupportedWithOffset(dataPosition + 1, protocol + 2)) break;
dataPosition += 2; // version is defined only by the last 2 bytes
uint16_t version{0};
std::memcpy(&version, dataPosition, sizeof(version));
if (!version) {
break;
}
for (const auto supportedVersion : kSupportedVersions) {
if (supportedVersion == version) {
std::memcpy(protocol + 2, &version, sizeof(version));
break;
}
if (CopyProtocolInformationIfSupported(version, protocol + 2)) {
break;
}
dataPosition += 2;

View File

@ -12,6 +12,7 @@
#pragma once
#include <fmt/format.h>
#include <optional>
#include "communication/bolt/v1/codes.hpp"
#include "communication/bolt/v1/state.hpp"
@ -22,10 +23,45 @@
namespace communication::bolt {
namespace detail {
namespace details {
template <typename TSession>
std::optional<Value> StateInitRunV1(TSession &session, const Marker marker) {
if (UNLIKELY(marker != Marker::TinyStruct2)) {
std::optional<State> AuthenticateUser(TSession &session, Value &metadata) {
// Get authentication data.
auto &data = metadata.ValueMap();
if (!data.count("scheme")) {
spdlog::warn("The client didn't supply authentication information!");
return State::Close;
}
std::string username;
std::string password;
if (data["scheme"].ValueString() == "basic") {
if (!data.count("principal") || !data.count("credentials")) {
spdlog::warn("The client didn't supply authentication information!");
return State::Close;
}
username = data["principal"].ValueString();
password = data["credentials"].ValueString();
} else if (data["scheme"].ValueString() != "none") {
spdlog::warn("Unsupported authentication scheme: {}", data["scheme"].ValueString());
return State::Close;
}
// Authenticate the user.
if (!session.Authenticate(username, password)) {
if (!session.encoder_.MessageFailure(
{{"code", "Memgraph.ClientError.Security.Unauthenticated"}, {"message", "Authentication failure"}})) {
spdlog::trace("Couldn't send failure message to the client!");
}
// Throw an exception to indicate to the network stack that the session
// should be closed and cleaned up.
throw SessionClosedException("The client is not authenticated!");
}
return std::nullopt;
}
template <typename TSession>
std::optional<Value> GetMetadataV1(TSession &session, const Marker marker) {
if (marker != Marker::TinyStruct2) [[unlikely]] {
spdlog::trace("Expected TinyStruct2 marker, but received 0x{:02X}!", utils::UnderlyingCast(marker));
spdlog::trace(
"The client sent malformed data, but we are continuing "
@ -53,8 +89,8 @@ std::optional<Value> StateInitRunV1(TSession &session, const Marker marker) {
}
template <typename TSession>
std::optional<Value> StateInitRunV4(TSession &session, const Marker marker) {
if (UNLIKELY(marker != Marker::TinyStruct1)) {
std::optional<Value> GetMetadataV4(TSession &session, const Marker marker) {
if (marker != Marker::TinyStruct1) [[unlikely]] {
spdlog::trace("Expected TinyStruct1 marker, but received 0x{:02X}!", utils::UnderlyingCast(marker));
spdlog::trace(
"The client sent malformed data, but we are continuing "
@ -80,15 +116,79 @@ std::optional<Value> StateInitRunV4(TSession &session, const Marker marker) {
return metadata;
}
} // namespace detail
template <typename TSession>
State SendSuccessMessage(TSession &session) {
// Neo4j's Java driver 4.1.1+ requires connection_id.
// The only usage in the mentioned version is for logging purposes.
// Because it's not critical for the regular usage of the driver
// we send a hardcoded value for now.
std::map<std::string, Value> metadata{{"connection_id", "bolt-1"}};
if (auto server_name = session.GetServerNameForInit(); server_name) {
metadata.insert({"server", *server_name});
}
bool success_sent = session.encoder_.MessageSuccess(metadata);
if (!success_sent) {
spdlog::trace("Couldn't send success message to the client!");
return State::Close;
}
return State::Idle;
}
template <typename TSession>
State StateInitRunV1(TSession &session, const Marker marker, const Signature signature) {
if (signature != Signature::Init) [[unlikely]] {
spdlog::trace("Expected Init signature, but received 0x{:02X}!", utils::UnderlyingCast(signature));
return State::Close;
}
auto maybeMetadata = GetMetadataV1(session, marker);
if (!maybeMetadata) {
return State::Close;
}
if (auto result = AuthenticateUser(session, *maybeMetadata)) {
return result.value();
}
return SendSuccessMessage(session);
}
template <typename TSession, int bolt_minor = 0>
State StateInitRunV4(TSession &session, Marker marker, Signature signature) {
if constexpr (bolt_minor > 0) {
if (signature == Signature::Noop) [[unlikely]] {
SPDLOG_DEBUG("Received NOOP message");
return State::Init;
}
}
if (signature != Signature::Init) [[unlikely]] {
spdlog::trace("Expected Init signature, but received 0x{:02X}!", utils::UnderlyingCast(signature));
return State::Close;
}
auto maybeMetadata = GetMetadataV4(session, marker);
if (!maybeMetadata) {
return State::Close;
}
if (auto result = AuthenticateUser(session, *maybeMetadata)) {
return result.value();
}
return SendSuccessMessage(session);
}
} // namespace details
/**
* Init state run function.
* This function runs everything to initialize a Bolt session with the client.
* @param session the session that should be used for the run.
*/
template <typename Session>
State StateInitRun(Session &session) {
template <typename TSession>
State StateInitRun(TSession &session) {
DMG_ASSERT(!session.encoder_buffer_.HasData(), "There should be no data to write in this state");
Marker marker;
@ -98,72 +198,18 @@ State StateInitRun(Session &session) {
return State::Close;
}
if (UNLIKELY(signature == Signature::Noop && session.version_.major == 4 && session.version_.minor == 1)) {
SPDLOG_DEBUG("Received NOOP message");
return State::Init;
}
if (UNLIKELY(signature != Signature::Init)) {
spdlog::trace("Expected Init signature, but received 0x{:02X}!", utils::UnderlyingCast(signature));
return State::Close;
}
auto maybeMetadata =
session.version_.major == 1 ? detail::StateInitRunV1(session, marker) : detail::StateInitRunV4(session, marker);
if (!maybeMetadata) {
return State::Close;
}
// Get authentication data.
std::string username;
std::string password;
auto &data = maybeMetadata->ValueMap();
if (!data.count("scheme")) {
spdlog::warn("The client didn't supply authentication information!");
return State::Close;
}
if (data["scheme"].ValueString() == "basic") {
if (!data.count("principal") || !data.count("credentials")) {
spdlog::warn("The client didn't supply authentication information!");
return State::Close;
switch (session.version_.major) {
case 1: {
return details::StateInitRunV1<TSession>(session, marker, signature);
}
username = data["principal"].ValueString();
password = data["credentials"].ValueString();
} else if (data["scheme"].ValueString() != "none") {
spdlog::warn("Unsupported authentication scheme: {}", data["scheme"].ValueString());
return State::Close;
}
// Authenticate the user.
if (!session.Authenticate(username, password)) {
if (!session.encoder_.MessageFailure(
{{"code", "Memgraph.ClientError.Security.Unauthenticated"}, {"message", "Authentication failure"}})) {
spdlog::trace("Couldn't send failure message to the client!");
}
// Throw an exception to indicate to the network stack that the session
// should be closed and cleaned up.
throw SessionClosedException("The client is not authenticated!");
}
// Return success.
{
bool success_sent = false;
// Neo4j's Java driver 4.1.1+ requires connection_id.
// The only usage in the mentioned version is for logging purposes.
// Because it's not critical for the regular usage of the driver
// we send a hardcoded value for now.
std::map<std::string, Value> metadata{{"connection_id", "bolt-1"}};
if (auto server_name = session.GetServerNameForInit(); server_name) {
metadata.insert({"server", *server_name});
}
success_sent = session.encoder_.MessageSuccess(metadata);
if (!success_sent) {
spdlog::trace("Couldn't send success message to the client!");
return State::Close;
case 4: {
if (session.version_.minor > 0) {
return details::StateInitRunV4<TSession, 1>(session, marker, signature);
}
return details::StateInitRunV4<TSession>(session, marker, signature);
}
}
return State::Idle;
spdlog::trace("Unsupported bolt version:{}.{})!", session.version_.major, session.version_.minor);
return State::Close;
}
} // namespace communication::bolt

View File

@ -138,6 +138,7 @@ constexpr uint8_t pullall_req[] = {0xb1, 0x3f, 0xa0};
constexpr uint8_t pull_one_req[] = {0xb1, 0x3f, 0xa1, 0x81, 0x6e, 0x01};
constexpr uint8_t reset_req[] = {0xb0, 0x0f};
constexpr uint8_t goodbye[] = {0xb0, 0x02};
constexpr uint8_t rollback[] = {0xb0, 0x13};
} // namespace v4
namespace v4_1 {
@ -147,6 +148,10 @@ constexpr uint8_t handshake_resp[] = {0x00, 0x00, 0x01, 0x04};
constexpr uint8_t noop[] = {0x00, 0x00};
} // namespace v4_1
namespace v4_3 {
constexpr uint8_t route[]{0xb0, 0x60};
} // namespace v4_3
// Write bolt chunk header (length)
void WriteChunkHeader(TestInputStream &input_stream, uint16_t len) {
len = utils::HostToBigEndian(len);
@ -320,6 +325,56 @@ TEST(BoltSession, HandshakeMultiVersionRequest) {
}
}
TEST(BoltSession, HandshakeWithVersionOffset) {
// It pick the versions depending on the offset given by the second byte
{
INIT_VARS;
const uint8_t priority_request[] = {0x60, 0x60, 0xb0, 0x17, 0x00, 0x03, 0x03, 0x04, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00};
const uint8_t priority_response[] = {0x00, 0x00, 0x03, 0x04};
ExecuteHandshake(input_stream, session, output, priority_request, priority_response);
ASSERT_EQ(session.version_.minor, 3);
ASSERT_EQ(session.version_.major, 4);
}
// This should pick 4.3 version since 4.4 and 4.5 are not existant
{
INIT_VARS;
const uint8_t priority_request[] = {0x60, 0x60, 0xb0, 0x17, 0x00, 0x03, 0x05, 0x04, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00};
const uint8_t priority_response[] = {0x00, 0x00, 0x03, 0x04};
ExecuteHandshake(input_stream, session, output, priority_request, priority_response);
ASSERT_EQ(session.version_.minor, 3);
ASSERT_EQ(session.version_.major, 4);
}
// With multiple offsets
{
INIT_VARS;
const uint8_t priority_request[] = {0x60, 0x60, 0xb0, 0x17, 0x00, 0x03, 0x03, 0x07, 0x00, 0x03,
0x03, 0x06, 0x00, 0x03, 0x03, 0x05, 0x00, 0x03, 0x03, 0x04};
const uint8_t priority_response[] = {0x00, 0x00, 0x03, 0x04};
ExecuteHandshake(input_stream, session, output, priority_request, priority_response);
ASSERT_EQ(session.version_.minor, 3);
ASSERT_EQ(session.version_.major, 4);
}
// Offset overflows
{
INIT_VARS;
const uint8_t priority_request[] = {0x60, 0x60, 0xb0, 0x17, 0x00, 0x07, 0x06, 0x04, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00};
const uint8_t priority_response[] = {0x00, 0x00, 0x03, 0x04};
ExecuteHandshake(input_stream, session, output, priority_request, priority_response);
ASSERT_EQ(session.version_.minor, 3);
ASSERT_EQ(session.version_.major, 4);
}
// Using offset but no version supported
{
INIT_VARS;
const uint8_t no_supported_versions_request[] = {0x60, 0x60, 0xb0, 0x17, 0x00, 0x03, 0x10, 0x04, 0x00, 0x00,
0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00};
ASSERT_THROW(ExecuteHandshake(input_stream, session, output, no_supported_versions_request), SessionException);
}
}
TEST(BoltSession, InitWrongSignature) {
INIT_VARS;
ExecuteHandshake(input_stream, session, output);
@ -773,7 +828,7 @@ TEST(BoltSession, PartialPull) {
ASSERT_EQ(session.state_, State::Result);
PrintOutput(output);
int len, num = 0;
int len{0}, num{0};
while (output.size() > 0) {
len = (output[0] << 8) + output[1];
output.erase(output.begin(), output.begin() + len + 4);
@ -889,3 +944,49 @@ TEST(BoltSession, Noop) {
ASSERT_THROW(ExecuteCommand(input_stream, session, v4_1::noop, sizeof(v4_1::noop)), SessionException);
}
}
TEST(BoltSession, Route) {
// Memgraph does not support route message, but it handles it
{
INIT_VARS;
ExecuteHandshake(input_stream, session, output);
ExecuteInit(input_stream, session, output);
ASSERT_THROW(ExecuteCommand(input_stream, session, v4_3::route, sizeof(v4_3::route)), SessionException);
}
{
INIT_VARS;
ExecuteHandshake(input_stream, session, output, v4::handshake_req, v4::handshake_resp);
ExecuteInit(input_stream, session, output, true);
ASSERT_THROW(ExecuteCommand(input_stream, session, v4_3::route, sizeof(v4_3::route)), SessionException);
}
}
TEST(BoltSession, Rollback) {
// v1 does not support ROLLBACK message
{
INIT_VARS;
ExecuteHandshake(input_stream, session, output);
ExecuteInit(input_stream, session, output);
ASSERT_THROW(ExecuteCommand(input_stream, session, v4::rollback, sizeof(v4::rollback)), SessionException);
}
// v4 supports ROLLBACK message
{
INIT_VARS;
ExecuteHandshake(input_stream, session, output, v4::handshake_req, v4::handshake_resp);
ExecuteInit(input_stream, session, output, true);
ExecuteCommand(input_stream, session, v4::rollback, sizeof(v4::rollback));
ASSERT_EQ(session.state_, State::Idle);
CheckSuccessMessage(output);
}
{
INIT_VARS;
ExecuteHandshake(input_stream, session, output, v4::handshake_req, v4::handshake_resp);
ASSERT_THROW(ExecuteCommand(input_stream, session, v4::rollback, sizeof(v4::rollback)), SessionException);
}
}