Add support for Bolt v4(.1) (#10)

* Added handshake support
* Add support for v4 hello and goodbye
* Add support for pulling n results
* Add support for transactions
* Add pull n for the dump
* Add support for NOOP
* Add support for multiple queries
* Update bolt session to support qid
* Update drivers test with multiple versions and go
* Extract failure handling into a function
* Use unique ptr instead of optional for query execution
* Destroy stream before query execution

Co-authored-by: Antonio Andelic <antonio.andelic@memgraph.io>
This commit is contained in:
antonio2368 2020-10-16 12:49:33 +02:00 committed by GitHub
parent 42ac5d4ea3
commit 0bcc1d67bc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
58 changed files with 2878 additions and 500 deletions

View File

@ -8,13 +8,18 @@ static constexpr uint8_t kPreamble[4] = {0x60, 0x60, 0xB0, 0x17};
static constexpr uint8_t kProtocol[4] = {0x00, 0x00, 0x00, 0x01};
enum class Signature : uint8_t {
Noop = 0x00,
Init = 0x01,
AckFailure = 0x0E,
AckFailure = 0x0E, // only v1
Reset = 0x0F,
Goodbye = 0x02,
Run = 0x10,
DiscardAll = 0x2F,
PullAll = 0x3F,
Discard = 0x2F,
Pull = 0x3F,
Begin = 0x11,
Commit = 0x12,
Rollback = 0x13,
Record = 0x71,
Success = 0x70,

View File

@ -1,6 +1,7 @@
#pragma once
#include <cstddef>
#include <cstdint>
namespace communication::bolt {
@ -15,4 +16,9 @@ static constexpr size_t kChunkWholeSize = kChunkHeaderSize + kChunkMaxDataSize;
* Handshake size defined in the Bolt protocol.
*/
static constexpr size_t kHandshakeSize = 20;

/**
 * Bolt protocol versions this server can negotiate during the handshake:
 * 1.0, 4.0 and 4.1. Each entry is a uint16 whose high byte is the major and
 * low byte the minor version (the handshake code memcpy-s the two on-wire
 * bytes [minor, major] into a uint16 — assumes a little-endian host; TODO
 * confirm if big-endian targets are supported).
 */
static constexpr uint16_t kSupportedVersions[3] = {0x0100, 0x0400, 0x0401};

// Sentinel value of the "n" field in a v4 PULL/DISCARD message meaning
// "all remaining rows".
static constexpr int kPullAll = -1;
// Sentinel value of the "qid" field meaning "the most recent query".
static constexpr int kPullLast = -1;
} // namespace communication::bolt

View File

@ -85,7 +85,7 @@ class ClientEncoder : private BaseEncoder<Buffer> {
* Writes a DiscardAll message.
*
* From the Bolt v1 documentation:
* DiscardAllMessage (signature=0x2F) {
* DiscardMessage (signature=0x2F) {
* }
*
* @returns true if the data was successfully sent to the client
@ -93,7 +93,7 @@ class ClientEncoder : private BaseEncoder<Buffer> {
*/
bool MessageDiscardAll() {
WriteRAW(utils::UnderlyingCast(Marker::TinyStruct));
WriteRAW(utils::UnderlyingCast(Signature::DiscardAll));
WriteRAW(utils::UnderlyingCast(Signature::Discard));
// Try to flush all remaining data in the buffer, but tell it that we will
// send more data (the end of message chunk).
if (!buffer_.Flush(true)) return false;
@ -113,7 +113,7 @@ class ClientEncoder : private BaseEncoder<Buffer> {
*/
bool MessagePullAll() {
WriteRAW(utils::UnderlyingCast(Marker::TinyStruct));
WriteRAW(utils::UnderlyingCast(Signature::PullAll));
WriteRAW(utils::UnderlyingCast(Signature::Pull));
// Try to flush all remaining data in the buffer, but tell it that we will
// send more data (the end of message chunk).
if (!buffer_.Flush(true)) return false;

View File

@ -49,14 +49,38 @@ class Session {
/**
* Process the given `query` with `params`.
* @return A pair which contains list of headers and qid which is set only
* if an explicit transaction was started.
*/
virtual std::vector<std::string> Interpret(
virtual std::pair<std::vector<std::string>, std::optional<int>> Interpret(
const std::string &query, const std::map<std::string, Value> &params) = 0;
/**
* Put results of the processed query in the `encoder`.
*
* @param n If set, defines amount of rows to be pulled from the result,
* otherwise all the rows are pulled.
* @param q If set, defines from which query to pull the results,
* otherwise the last query is used.
*/
virtual std::map<std::string, Value> PullAll(TEncoder *encoder) = 0;
virtual std::map<std::string, Value> Pull(TEncoder *encoder,
std::optional<int> n,
std::optional<int> qid) = 0;
/**
* Discard results of the processed query.
*
* @param n If set, defines amount of rows to be discarded from the result,
* otherwise all the rows are discarded.
* @param q If set, defines from which query to discard the results,
* otherwise the last query is used.
*/
virtual std::map<std::string, Value> Discard(std::optional<int> n,
std::optional<int> qid) = 0;
virtual void BeginTransaction() = 0;
virtual void CommitTransaction() = 0;
virtual void RollbackTransaction() = 0;
/** Aborts currently running query. */
virtual void Abort() = 0;
@ -142,6 +166,13 @@ class Session {
bool handshake_done_{false};
State state_{State::Handshake};
struct Version {
uint8_t major;
uint8_t minor;
};
Version version_;
private:
void ClientFailureInvalidData() {
// Set the state to Close.

View File

@ -7,6 +7,7 @@
#include "communication/bolt/v1/state.hpp"
#include "communication/bolt/v1/value.hpp"
#include "utils/cast.hpp"
#include "utils/likely.hpp"
namespace communication::bolt {
@ -25,10 +26,17 @@ State StateErrorRun(TSession &session, State state) {
return State::Close;
}
if (UNLIKELY(signature == Signature::Noop && session.version_.major == 4 &&
session.version_.minor == 1)) {
DLOG(INFO) << "Received NOOP message";
return state;
}
// Clear the data buffer if it has any leftover data.
session.encoder_buffer_.Clear();
if (signature == Signature::AckFailure || signature == Signature::Reset) {
if ((session.version_.major == 1 && signature == Signature::AckFailure) ||
signature == Signature::Reset) {
if (signature == Signature::AckFailure) {
DLOG(INFO) << "AckFailure received";
} else {
@ -39,6 +47,7 @@ State StateErrorRun(TSession &session, State state) {
DLOG(WARNING) << "Couldn't send success message!";
return State::Close;
}
if (signature == Signature::Reset) {
session.Abort();
return State::Idle;

View File

@ -7,9 +7,12 @@
#include <glog/logging.h>
#include "communication/bolt/v1/codes.hpp"
#include "communication/bolt/v1/constants.hpp"
#include "communication/bolt/v1/exceptions.hpp"
#include "communication/bolt/v1/state.hpp"
#include "communication/bolt/v1/value.hpp"
#include "communication/exceptions.hpp"
#include "utils/likely.hpp"
namespace communication::bolt {
@ -59,19 +62,39 @@ inline std::pair<std::string, std::string> ExceptionToErrorMessage(
"should be in database logs."};
}
template <typename TSession>
/// Reports a caught exception to the client as a Bolt FAILURE message.
/// Clears any partially encoded results first, then sends the mapped
/// error code/message pair.
/// @return State::Error when the failure was delivered, State::Close when
///         even the failure message could not be sent.
inline State HandleFailure(TSession &session, const std::exception &e) {
  DLOG(WARNING) << fmt::format("Error message: {}", e.what());
  const auto *stacktrace_exception =
      dynamic_cast<const utils::StacktraceException *>(&e);
  if (stacktrace_exception) {
    DLOG(WARNING) << fmt::format("Error trace: {}",
                                 stacktrace_exception->trace());
  }
  // Drop any buffered (partial) response data before reporting the failure.
  session.encoder_buffer_.Clear();
  const auto [code, message] = ExceptionToErrorMessage(e);
  if (!session.encoder_.MessageFailure(
          {{"code", code}, {"message", message}})) {
    DLOG(WARNING) << "Couldn't send failure message!";
    return State::Close;
  }
  return State::Error;
}
template <typename TSession>
State HandleRun(TSession &session, State state, Marker marker) {
const std::map<std::string, Value> kEmptyFields = {
{"fields", std::vector<Value>{}}};
if (marker != Marker::TinyStruct2) {
const auto expected_marker =
session.version_.major == 1 ? Marker::TinyStruct2 : Marker::TinyStruct3;
if (marker != expected_marker) {
DLOG(WARNING) << fmt::format(
"Expected TinyStruct2 marker, but received 0x{:02X}!",
"Expected {} marker, but received 0x{:02X}!",
session.version_.major == 1 ? "TinyStruct2" : "TinyStruct3",
utils::UnderlyingCast(marker));
return State::Close;
}
Value query, params;
Value query, params, extra;
if (!session.decoder_.ReadValue(&query, Value::Type::String)) {
DLOG(WARNING) << "Couldn't read query string!";
return State::Close;
@ -82,6 +105,12 @@ State HandleRun(TSession &session, State state, Marker marker) {
return State::Close;
}
if (session.version_.major == 4) {
if (!session.decoder_.ReadValue(&extra, Value::Type::Map)) {
DLOG(WARNING) << "Couldn't read extra field!";
}
}
if (state != State::Idle) {
// Client could potentially recover if we move to error state, but there is
// no legitimate situation in which well working client would end up in this
@ -97,7 +126,8 @@ State HandleRun(TSession &session, State state, Marker marker) {
try {
// Interpret can throw.
auto header = session.Interpret(query.ValueString(), params.ValueMap());
auto [header, qid] =
session.Interpret(query.ValueString(), params.ValueMap());
// Convert std::string to Value
std::vector<Value> vec;
std::map<std::string, Value> data;
@ -111,85 +141,90 @@ State HandleRun(TSession &session, State state, Marker marker) {
}
return State::Result;
} catch (const std::exception &e) {
DLOG(WARNING) << fmt::format("Error message: {}", e.what());
if (const auto *p = dynamic_cast<const utils::StacktraceException *>(&e)) {
DLOG(WARNING) << fmt::format("Error trace: {}", p->trace());
}
session.encoder_buffer_.Clear();
auto code_message = ExceptionToErrorMessage(e);
bool fail_sent = session.encoder_.MessageFailure(
{{"code", code_message.first}, {"message", code_message.second}});
if (!fail_sent) {
DLOG(WARNING) << "Couldn't send failure message!";
return State::Close;
}
return State::Error;
return HandleFailure(session, e);
}
}
template <typename Session>
State HandlePullAll(Session &session, State state, Marker marker) {
if (marker != Marker::TinyStruct) {
namespace detail {
template <bool is_pull, typename TSession>
State HandlePullDiscard(TSession &session, State state, Marker marker) {
const auto expected_marker =
session.version_.major == 1 ? Marker::TinyStruct : Marker::TinyStruct1;
if (marker != expected_marker) {
DLOG(WARNING) << fmt::format(
"Expected TinyStruct marker, but received 0x{:02X}!",
"Expected {} marker, but received 0x{:02X}!",
session.version_.major == 1 ? "TinyStruct" : "TinyStruct1",
utils::UnderlyingCast(marker));
return State::Close;
}
if (state != State::Result) {
DLOG(WARNING) << "Unexpected PULL_ALL!";
if constexpr (is_pull) {
DLOG(WARNING) << "Unexpected PULL!";
} else {
DLOG(WARNING) << "Unexpected DISCARD!";
}
// Same as `unexpected RUN` case.
return State::Close;
}
try {
// PullAll can throw.
auto summary = session.PullAll(&session.encoder_);
std::optional<int> n;
std::optional<int> qid;
if (session.version_.major == 4) {
Value extra;
if (!session.decoder_.ReadValue(&extra, Value::Type::Map)) {
DLOG(WARNING) << "Couldn't read extra field!";
}
const auto &extra_map = extra.ValueMap();
if (extra_map.count("n")) {
if (const auto n_value = extra_map.at("n").ValueInt();
n_value != kPullAll) {
n = n_value;
}
}
if (extra_map.count("qid")) {
if (const auto qid_value = extra_map.at("qid").ValueInt();
qid_value != kPullLast) {
qid = qid_value;
}
}
}
std::map<std::string, Value> summary;
if constexpr (is_pull) {
// Pull can throw.
summary = session.Pull(&session.encoder_, n, qid);
} else {
summary = session.Discard(n, qid);
}
if (!session.encoder_.MessageSuccess(summary)) {
DLOG(WARNING) << "Couldn't send query summary!";
return State::Close;
}
if (summary.count("has_more") && summary.at("has_more").ValueBool()) {
return State::Result;
}
return State::Idle;
} catch (const std::exception &e) {
DLOG(WARNING) << fmt::format("Error message: {}", e.what());
if (const auto *p = dynamic_cast<const utils::StacktraceException *>(&e)) {
DLOG(WARNING) << fmt::format("Error trace: {}", p->trace());
}
session.encoder_buffer_.Clear();
auto code_message = ExceptionToErrorMessage(e);
bool fail_sent = session.encoder_.MessageFailure(
{{"code", code_message.first}, {"message", code_message.second}});
if (!fail_sent) {
DLOG(WARNING) << "Couldn't send failure message!";
return State::Close;
}
return State::Error;
return HandleFailure(session, e);
}
}
} // namespace detail
template <typename Session>
/// Handles a PULL (v4) / PULL_ALL (v1) message by forwarding to the shared
/// pull/discard implementation with the "pull" branch selected.
State HandlePull(Session &session, State state, Marker marker) {
  return detail::HandlePullDiscard<true>(session, state, marker);
}
template <typename Session>
State HandleDiscardAll(Session &session, State state, Marker marker) {
if (marker != Marker::TinyStruct) {
DLOG(WARNING) << fmt::format(
"Expected TinyStruct marker, but received 0x{:02X}!",
utils::UnderlyingCast(marker));
return State::Close;
}
if (state != State::Result) {
DLOG(WARNING) << "Unexpected DISCARD_ALL!";
// Same as `unexpected RUN` case.
return State::Close;
}
// Clear all pending data and send a success message.
session.encoder_buffer_.Clear();
if (!session.encoder_.MessageSuccess()) {
DLOG(WARNING) << "Couldn't send success message!";
return State::Close;
}
return State::Idle;
State HandleDiscard(Session &session, State state, Marker marker) {
return detail::HandlePullDiscard<false>(session, state, marker);
}
template <typename Session>
@ -212,6 +247,7 @@ State HandleReset(Session &session, State, Marker marker) {
// Clear all pending data and send a success message.
session.encoder_buffer_.Clear();
if (!session.encoder_.MessageSuccess()) {
DLOG(WARNING) << "Couldn't send success message!";
return State::Close;
@ -222,6 +258,116 @@ State HandleReset(Session &session, State, Marker marker) {
return State::Idle;
}
template <typename Session>
/// Handles a BEGIN message (explicit transaction start, Bolt v4 only).
/// Validates the marker and session state, reads the extra-fields map,
/// starts the transaction and acknowledges it with SUCCESS.
/// @return State::Idle on success, State::Close on protocol violations,
///         HandleFailure's result when BeginTransaction throws.
State HandleBegin(Session &session, State state, Marker marker) {
  if (session.version_.major == 1) {
    // Fixed typo: "messsage" -> "message".
    DLOG(WARNING) << "BEGIN message not supported in Bolt v1!";
    return State::Close;
  }

  if (marker != Marker::TinyStruct1) {
    DLOG(WARNING) << fmt::format(
        "Expected TinyStruct1 marker, but received 0x{:02x}!",
        utils::UnderlyingCast(marker));
    return State::Close;
  }

  Value extra;
  if (!session.decoder_.ReadValue(&extra, Value::Type::Map)) {
    DLOG(WARNING) << "Couldn't read extra fields!";
    return State::Close;
  }

  if (state != State::Idle) {
    DLOG(WARNING) << "Unexpected BEGIN command!";
    return State::Close;
  }

  DCHECK(!session.encoder_buffer_.HasData())
      << "There should be no data to write in this state";

  // Start the transaction *before* acknowledging the message; otherwise a
  // throwing BeginTransaction() would make the client receive a SUCCESS
  // immediately followed by a FAILURE for the same request.
  try {
    session.BeginTransaction();
  } catch (const std::exception &e) {
    return HandleFailure(session, e);
  }

  if (!session.encoder_.MessageSuccess({})) {
    DLOG(WARNING) << "Couldn't send success message!";
    return State::Close;
  }

  return State::Idle;
}
template <typename Session>
/// Handles a COMMIT message (explicit transaction commit, Bolt v4 only).
/// Validates the marker and session state, commits the transaction and
/// acknowledges it with SUCCESS.
/// @return State::Idle on success, State::Close on protocol violations,
///         HandleFailure's result when CommitTransaction throws.
State HandleCommit(Session &session, State state, Marker marker) {
  if (session.version_.major == 1) {
    // Fixed typo: "messsage" -> "message".
    DLOG(WARNING) << "COMMIT message not supported in Bolt v1!";
    return State::Close;
  }

  if (marker != Marker::TinyStruct) {
    DLOG(WARNING) << fmt::format(
        "Expected TinyStruct marker, but received 0x{:02x}!",
        utils::UnderlyingCast(marker));
    return State::Close;
  }

  if (state != State::Idle) {
    DLOG(WARNING) << "Unexpected COMMIT command!";
    return State::Close;
  }

  DCHECK(!session.encoder_buffer_.HasData())
      << "There should be no data to write in this state";

  // Commit *before* sending SUCCESS; otherwise a throwing
  // CommitTransaction() would make the client receive a SUCCESS
  // immediately followed by a FAILURE for the same request.
  try {
    session.CommitTransaction();
  } catch (const std::exception &e) {
    return HandleFailure(session, e);
  }

  if (!session.encoder_.MessageSuccess({})) {
    DLOG(WARNING) << "Couldn't send success message!";
    return State::Close;
  }

  return State::Idle;
}
template <typename Session>
/// Handles a ROLLBACK message (explicit transaction abort, Bolt v4 only).
/// Validates the marker and session state, rolls the transaction back and
/// acknowledges it with SUCCESS.
/// @return State::Idle on success, State::Close on protocol violations,
///         HandleFailure's result when RollbackTransaction throws.
State HandleRollback(Session &session, State state, Marker marker) {
  if (session.version_.major == 1) {
    // Fixed typo: "messsage" -> "message".
    DLOG(WARNING) << "ROLLBACK message not supported in Bolt v1!";
    return State::Close;
  }

  if (marker != Marker::TinyStruct) {
    DLOG(WARNING) << fmt::format(
        "Expected TinyStruct marker, but received 0x{:02x}!",
        utils::UnderlyingCast(marker));
    return State::Close;
  }

  if (state != State::Idle) {
    DLOG(WARNING) << "Unexpected ROLLBACK command!";
    return State::Close;
  }

  DCHECK(!session.encoder_buffer_.HasData())
      << "There should be no data to write in this state";

  // Roll back *before* sending SUCCESS; otherwise a throwing
  // RollbackTransaction() would make the client receive a SUCCESS
  // immediately followed by a FAILURE for the same request.
  try {
    session.RollbackTransaction();
  } catch (const std::exception &e) {
    return HandleFailure(session, e);
  }

  if (!session.encoder_.MessageSuccess({})) {
    DLOG(WARNING) << "Couldn't send success message!";
    return State::Close;
  }

  return State::Idle;
}
/**
* Executor state run function
* This function executes an initialized Bolt session.
@ -237,16 +383,30 @@ State StateExecutingRun(Session &session, State state) {
return State::Close;
}
if (UNLIKELY(signature == Signature::Noop && session.version_.major == 4 &&
session.version_.minor == 1)) {
DLOG(INFO) << "Received NOOP message";
return state;
}
if (signature == Signature::Run) {
return HandleRun(session, state, marker);
} else if (signature == Signature::PullAll) {
return HandlePullAll(session, state, marker);
} else if (signature == Signature::DiscardAll) {
return HandleDiscardAll(session, state, marker);
} else if (signature == Signature::Pull) {
return HandlePull(session, state, marker);
} else if (signature == Signature::Discard) {
return HandleDiscard(session, state, marker);
} else if (signature == Signature::Begin) {
return HandleBegin(session, state, marker);
} else if (signature == Signature::Commit) {
return HandleCommit(session, state, marker);
} else if (signature == Signature::Rollback) {
return HandleRollback(session, state, marker);
} else if (signature == Signature::Reset) {
return HandleReset(session, state, marker);
} else if (signature == Signature::Goodbye && session.version_.major != 1) {
throw SessionClosedException("Closing connection.");
} else {
DLOG(WARNING) << fmt::format("Unrecognized signature recieved (0x{:02X})!",
DLOG(WARNING) << fmt::format("Unrecognized signature received (0x{:02X})!",
utils::UnderlyingCast(signature));
return State::Close;
}

View File

@ -2,6 +2,8 @@
#include <glog/logging.h>
#include <fmt/format.h>
#include "communication/bolt/v1/codes.hpp"
#include "communication/bolt/v1/constants.hpp"
#include "communication/bolt/v1/state.hpp"
@ -16,25 +18,57 @@ namespace communication::bolt {
*/
template <typename TSession>
State StateHandshakeRun(TSession &session) {
auto precmp = memcmp(session.input_stream_.data(), kPreamble, sizeof(kPreamble));
auto precmp =
std::memcmp(session.input_stream_.data(), kPreamble, sizeof(kPreamble));
if (UNLIKELY(precmp != 0)) {
DLOG(WARNING) << "Received a wrong preamble!";
return State::Close;
}
// TODO so far we only support version 1 of the protocol so it doesn't
// make sense to check which version the client prefers this will change in
// the future.
DCHECK(session.input_stream_.size() >= kHandshakeSize)
<< "Wrong size of the handshake data!";
if (!session.output_stream_.Write(kProtocol, sizeof(kProtocol))) {
auto dataPosition = session.input_stream_.data() + sizeof(kPreamble);
uint8_t protocol[4] = {0x00};
for (int i = 0; i < 4 && !protocol[3]; ++i) {
dataPosition += 2; // version is defined only by the last 2 bytes
uint16_t version = 0;
std::memcpy(&version, dataPosition, sizeof(version));
if (!version) {
break;
}
for (const auto supportedVersion : kSupportedVersions) {
if (supportedVersion == version) {
std::memcpy(protocol + 2, &version, sizeof(version));
break;
}
}
dataPosition += 2;
}
session.version_.minor = protocol[2];
session.version_.major = protocol[3];
if (!session.version_.major) {
DLOG(WARNING) << "Server doesn't support any of the requested versions!";
return State::Close;
}
if (!session.output_stream_.Write(protocol, sizeof(protocol))) {
DLOG(WARNING) << "Couldn't write handshake response!";
return State::Close;
}
DLOG(INFO) << fmt::format("Using version {}.{} of protocol",
session.version_.major, session.version_.minor);
// Delete the handshake data from the input stream. It is guaranteed that
// there will be at least 20 bytes (kHandshakeSize) in the buffer.
session.input_stream_.Shift(kHandshakeSize);
return State::Init;
}
}
} // namespace communication::bolt

View File

@ -11,6 +11,71 @@
namespace communication::bolt {
namespace detail {
template <typename TSession>
std::optional<Value> StateInitRunV1(TSession &session, const Marker marker) {
if (UNLIKELY(marker != Marker::TinyStruct2)) {
DLOG(WARNING) << fmt::format(
"Expected TinyStruct2 marker, but received 0x{:02X}!",
utils::UnderlyingCast(marker));
DLOG(WARNING) << "The client sent malformed data, but we are continuing "
"because the official Neo4j Java driver sends malformed "
"data. D'oh!";
// TODO: this should be uncommented when the Neo4j Java driver is fixed
// return State::Close;
}
Value client_name;
if (!session.decoder_.ReadValue(&client_name, Value::Type::String)) {
DLOG(WARNING) << "Couldn't read client name!";
return std::nullopt;
}
Value metadata;
if (!session.decoder_.ReadValue(&metadata, Value::Type::Map)) {
DLOG(WARNING) << "Couldn't read metadata!";
return std::nullopt;
}
LOG(INFO) << fmt::format("Client connected '{}'", client_name.ValueString())
<< std::endl;
return metadata;
}
template <typename TSession>
std::optional<Value> StateInitRunV4(TSession &session, const Marker marker) {
if (UNLIKELY(marker != Marker::TinyStruct1)) {
DLOG(WARNING) << fmt::format(
"Expected TinyStruct1 marker, but received 0x{:02X}!",
utils::UnderlyingCast(marker));
DLOG(WARNING) << "The client sent malformed data, but we are continuing "
"because the official Neo4j Java driver sends malformed "
"data. D'oh!";
// TODO: this should be uncommented when the Neo4j Java driver is fixed
// return State::Close;
}
Value metadata;
if (!session.decoder_.ReadValue(&metadata, Value::Type::Map)) {
DLOG(WARNING) << "Couldn't read metadata!";
return std::nullopt;
}
const auto &data = metadata.ValueMap();
if (!data.count("user_agent")) {
LOG(WARNING) << "The client didn't supply the user agent!";
return std::nullopt;
}
LOG(INFO) << fmt::format("Client connected '{}'",
data.at("user_agent").ValueString())
<< std::endl;
return metadata;
}
} // namespace detail
/**
* Init state run function.
* This function runs everything to initialize a Bolt session with the client.
@ -28,41 +93,31 @@ State StateInitRun(Session &session) {
return State::Close;
}
if (UNLIKELY(signature == Signature::Noop && session.version_.major == 4 &&
session.version_.minor == 1)) {
DLOG(INFO) << "Received NOOP message";
return State::Init;
}
if (UNLIKELY(signature != Signature::Init)) {
DLOG(WARNING) << fmt::format(
"Expected Init signature, but received 0x{:02X}!",
utils::UnderlyingCast(signature));
return State::Close;
}
if (UNLIKELY(marker != Marker::TinyStruct2)) {
DLOG(WARNING) << fmt::format(
"Expected TinyStruct2 marker, but received 0x{:02X}!",
utils::UnderlyingCast(marker));
DLOG(WARNING) << "The client sent malformed data, but we are continuing "
"because the official Neo4j Java driver sends malformed "
"data. D'oh!";
// TODO: this should be uncommented when the Neo4j Java driver is fixed
// return State::Close;
}
Value client_name;
if (!session.decoder_.ReadValue(&client_name, Value::Type::String)) {
DLOG(WARNING) << "Couldn't read client name!";
auto maybeMetadata = session.version_.major == 1
? detail::StateInitRunV1(session, marker)
: detail::StateInitRunV4(session, marker);
if (!maybeMetadata) {
return State::Close;
}
Value metadata;
if (!session.decoder_.ReadValue(&metadata, Value::Type::Map)) {
DLOG(WARNING) << "Couldn't read metadata!";
return State::Close;
}
LOG(INFO) << fmt::format("Client connected '{}'", client_name.ValueString())
<< std::endl;
// Get authentication data.
std::string username, password;
auto &data = metadata.ValueMap();
std::string username;
std::string password;
auto &data = maybeMetadata->ValueMap();
if (!data.count("scheme")) {
LOG(WARNING) << "The client didn't supply authentication information!";
return State::Close;
@ -95,13 +150,15 @@ State StateInitRun(Session &session) {
// Return success.
{
bool success_sent = false;
auto server_name = session.GetServerNameForInit();
if (server_name) {
success_sent =
session.encoder_.MessageSuccess({{"server", *server_name}});
} else {
success_sent = session.encoder_.MessageSuccess();
// Neo4j's Java driver 4.1.1+ requires connection_id.
// The only usage in the mentioned version is for logging purposes.
// Because it's not critical for the regular usage of the driver
// we send a hardcoded value for now.
std::map<std::string, Value> metadata{{"connection_id", "bolt-1"}};
if (auto server_name = session.GetServerNameForInit(); server_name) {
metadata.insert({"server", *server_name});
}
success_sent = session.encoder_.MessageSuccess(metadata);
if (!success_sent) {
DLOG(WARNING) << "Couldn't send success message to the client!";
return State::Close;

View File

@ -15,6 +15,7 @@
#include <gflags/gflags.h>
#include <glog/logging.h>
#include "communication/bolt/v1/constants.hpp"
#include "helpers.hpp"
#include "py/py.hpp"
#include "query/exceptions.hpp"
@ -213,7 +214,13 @@ class BoltSession final
using communication::bolt::Session<communication::InputStream,
communication::OutputStream>::TEncoder;
std::vector<std::string> Interpret(
void BeginTransaction() override { interpreter_.BeginTransaction(); }
void CommitTransaction() override { interpreter_.CommitTransaction(); }
void RollbackTransaction() override { interpreter_.RollbackTransaction(); }
std::pair<std::vector<std::string>, std::optional<int>> Interpret(
const std::string &query,
const std::map<std::string, communication::bolt::Value> &params)
override {
@ -229,7 +236,7 @@ class BoltSession final
#ifdef MG_ENTERPRISE
if (user_) {
const auto &permissions = user_->GetPermissions();
for (const auto &privilege : result.second) {
for (const auto &privilege : result.privileges) {
if (permissions.Has(glue::PrivilegeToPermission(privilege)) !=
auth::PermissionLevel::GRANT) {
interpreter_.Abort();
@ -240,7 +247,7 @@ class BoltSession final
}
}
#endif
return result.first;
return {result.headers, result.qid};
} catch (const query::QueryException &e) {
// Wrap QueryException into ClientError, because we want to allow the
@ -249,11 +256,43 @@ class BoltSession final
}
}
std::map<std::string, communication::bolt::Value> PullAll(
TEncoder *encoder) override {
/// Pulls results of a processed query and streams them to the client.
///
/// @param encoder encoder used to send each result row over Bolt.
/// @param n if set, at most this many rows are pulled; otherwise all rows.
/// @param qid if set, pulls from that query; otherwise the last query is
///        used (see the Session interface contract).
/// @return the query summary converted to Bolt values.
std::map<std::string, communication::bolt::Value> Pull(
    TEncoder *encoder, std::optional<int> n,
    std::optional<int> qid) override {
  // TypedValueResultStream converts TypedValue rows to Bolt values before
  // forwarding them to the encoder.
  TypedValueResultStream stream(encoder, db_);
  return PullResults(stream, n, qid);
}
/// Discards results of a processed query without sending them to the client.
///
/// @param n if set, at most this many rows are discarded; otherwise all rows.
/// @param qid if set, discards from that query; otherwise the last query is
///        used (see the Session interface contract).
/// @return the query summary converted to Bolt values.
std::map<std::string, communication::bolt::Value> Discard(
    std::optional<int> n, std::optional<int> qid) override {
  // DiscardValueResultStream swallows every row, so the interpreter advances
  // past the results without encoding anything.
  DiscardValueResultStream stream;
  return PullResults(stream, n, qid);
}
void Abort() override { interpreter_.Abort(); }
bool Authenticate(const std::string &username,
const std::string &password) override {
#ifdef MG_ENTERPRISE
if (!auth_->HasUsers()) return true;
user_ = auth_->Authenticate(username, password);
return !!user_;
#else
return true;
#endif
}
std::optional<std::string> GetServerNameForInit() override {
if (FLAGS_bolt_server_name_for_init.empty()) return std::nullopt;
return FLAGS_bolt_server_name_for_init;
}
private:
template <typename TStream>
std::map<std::string, communication::bolt::Value> PullResults(
TStream &stream, std::optional<int> n, std::optional<int> qid) {
try {
TypedValueResultStream stream(encoder, db_);
const auto &summary = interpreter_.PullAll(&stream);
const auto &summary = interpreter_.Pull(&stream, n, qid);
std::map<std::string, communication::bolt::Value> decoded_summary;
for (const auto &kv : summary) {
auto maybe_value =
@ -279,25 +318,6 @@ class BoltSession final
}
}
void Abort() override { interpreter_.Abort(); }
bool Authenticate(const std::string &username,
const std::string &password) override {
#ifdef MG_ENTERPRISE
if (!auth_->HasUsers()) return true;
user_ = auth_->Authenticate(username, password);
return !!user_;
#else
return true;
#endif
}
std::optional<std::string> GetServerNameForInit() override {
if (FLAGS_bolt_server_name_for_init.empty()) return std::nullopt;
return FLAGS_bolt_server_name_for_init;
}
private:
/// Wrapper around TEncoder which converts TypedValue to Value
/// before forwarding the calls to original TEncoder.
class TypedValueResultStream {
@ -336,6 +356,12 @@ class BoltSession final
const storage::Storage *db_;
};
/// No-op result sink used by `Discard`: rows are consumed without being
/// encoded or sent, letting the interpreter advance past them.
struct DiscardValueResultStream {
  void Result(const std::vector<query::TypedValue> &) {
    // do nothing
  }
};
// NOTE: Needed only for ToBoltValue conversions
const storage::Storage *db_;
query::Interpreter interpreter_;

View File

@ -10,7 +10,11 @@
#include <glog/logging.h>
#include "query/db_accessor.hpp"
#include "query/exceptions.hpp"
#include "query/stream.hpp"
#include "query/typed_value.hpp"
#include "storage/v2/storage.hpp"
#include "utils/algorithm.hpp"
#include "utils/string.hpp"
@ -219,71 +223,310 @@ void DumpUniqueConstraint(std::ostream *os, query::DbAccessor *dba,
} // namespace
void DumpDatabaseToCypherQueries(query::DbAccessor *dba, AnyStream *stream) {
{
auto info = dba->ListAllIndices();
for (const auto &item : info.label) {
std::ostringstream os;
DumpLabelIndex(&os, dba, item);
stream->Result({TypedValue(os.str())});
}
for (const auto &item : info.label_property) {
std::ostringstream os;
DumpLabelPropertyIndex(&os, dba, item.first, item.second);
stream->Result({TypedValue(os.str())});
}
}
{
auto info = dba->ListAllConstraints();
for (const auto &item : info.existence) {
std::ostringstream os;
DumpExistenceConstraint(&os, dba, item.first, item.second);
stream->Result({TypedValue(os.str())});
}
for (const auto &item : info.unique) {
std::ostringstream os;
DumpUniqueConstraint(&os, dba, item.first, item.second);
stream->Result({TypedValue(os.str())});
}
}
/// Builds the ordered sequence of pull chunks that together produce the full
/// database dump. Chunks are executed lazily by `Pull`, strictly in this
/// order; the internal-index chunks bracket the vertex/edge dump because the
/// edge-creation queries rely on that index existing.
PullPlanDump::PullPlanDump(DbAccessor *dba)
    : dba_(dba),
      vertices_iterable_(dba->Vertices(storage::View::OLD)),
      pull_chunks_{// Dump all label indices
                   CreateLabelIndicesPullChunk(),
                   // Dump all label property indices
                   CreateLabelPropertyIndicesPullChunk(),
                   // Dump all existence constraints
                   CreateExistenceConstraintsPullChunk(),
                   // Dump all unique constraints
                   CreateUniqueConstraintsPullChunk(),
                   // Create internal index for faster edge creation
                   CreateInternalIndexPullChunk(),
                   // Dump all vertices
                   CreateVertexPullChunk(),
                   // Dump all edges
                   CreateEdgePullChunk(),
                   // Drop the internal index
                   CreateDropInternalIndexPullChunk(),
                   // Internal index cleanup
                   CreateInternalIndexCleanupPullChunk()} {}
auto vertices = dba->Vertices(storage::View::OLD);
bool internal_index_created = false;
if (vertices.begin() != vertices.end()) {
std::ostringstream os;
os << "CREATE INDEX ON :" << kInternalVertexLabel << "("
<< kInternalPropertyId << ");";
stream->Result({TypedValue(os.str())});
internal_index_created = true;
}
for (const auto &vertex : vertices) {
std::ostringstream os;
DumpVertex(&os, dba, vertex);
stream->Result({TypedValue(os.str())});
}
for (const auto &vertex : vertices) {
auto maybe_edges = vertex.OutEdges(storage::View::OLD);
CHECK(maybe_edges.HasValue()) << "Invalid database state!";
for (const auto &edge : *maybe_edges) {
std::ostringstream os;
DumpEdge(&os, dba, edge);
stream->Result({TypedValue(os.str())});
bool PullPlanDump::Pull(AnyStream *stream, std::optional<int> n) {
// Iterate all functions that stream some results.
// Each function should return number of results it streamed after it
// finishes. If the function did not finish streaming all the results,
// std::nullopt should be returned because n results have already been sent.
while (current_chunk_index_ < pull_chunks_.size() && (!n || *n > 0)) {
const auto maybe_streamed_count =
pull_chunks_[current_chunk_index_](stream, n);
if (!maybe_streamed_count) {
// n wasn't large enough to stream all the results from the current chunk
break;
}
if (n) {
// chunk finished streaming its results
// subtract number of results streamed in current pull
// so we know how many results we need to stream from future
// chunks.
*n -= *maybe_streamed_count;
}
++current_chunk_index_;
}
if (internal_index_created) {
{
return current_chunk_index_ == pull_chunks_.size();
}
/// Creates the pull chunk that dumps all label indices.
/// The returned callable streams at most `n` statements per invocation and
/// keeps its position in a mutable capture so it can resume on the next pull.
/// It returns the number of streamed results once the chunk is exhausted,
/// or std::nullopt when the `n` budget ran out first.
PullPlanDump::PullChunk PullPlanDump::CreateLabelIndicesPullChunk() {
  return [this, next_index = 0U](
             AnyStream *stream,
             std::optional<int> n) mutable -> std::optional<size_t> {
    // The index listing is built lazily, on the first invocation only.
    if (!indices_info_) {
      indices_info_.emplace(dba_->ListAllIndices());
    }
    const auto &label_indices = indices_info_->label;

    size_t streamed = 0;
    for (; next_index < label_indices.size() && (!n || streamed < *n);
         ++next_index, ++streamed) {
      std::ostringstream os;
      DumpLabelIndex(&os, dba_, label_indices[next_index]);
      stream->Result({TypedValue(os.str())});
    }

    if (next_index < label_indices.size()) {
      // Budget exhausted before the chunk finished.
      return std::nullopt;
    }
    return streamed;
  };
}
/// Creates the pull chunk that dumps all label+property indices.
/// The returned callable streams at most `n` statements per invocation and
/// keeps its position in a mutable capture so it can resume on the next pull.
/// It returns the number of streamed results once the chunk is exhausted,
/// or std::nullopt when the `n` budget ran out first.
PullPlanDump::PullChunk PullPlanDump::CreateLabelPropertyIndicesPullChunk() {
  return [this, next_index = 0U](
             AnyStream *stream,
             std::optional<int> n) mutable -> std::optional<size_t> {
    // The index listing is built lazily, on the first invocation only.
    if (!indices_info_) {
      indices_info_.emplace(dba_->ListAllIndices());
    }
    const auto &label_property_indices = indices_info_->label_property;

    size_t streamed = 0;
    for (; next_index < label_property_indices.size() && (!n || streamed < *n);
         ++next_index, ++streamed) {
      const auto &[label, property] = label_property_indices[next_index];
      std::ostringstream os;
      DumpLabelPropertyIndex(&os, dba_, label, property);
      stream->Result({TypedValue(os.str())});
    }

    if (next_index < label_property_indices.size()) {
      // Budget exhausted before the chunk finished.
      return std::nullopt;
    }
    return streamed;
  };
}
PullPlanDump::PullChunk PullPlanDump::CreateExistenceConstraintsPullChunk() {
  // Streams one statement per existence constraint; resumes from
  // `next_index` across successive pulls.
  return [this, next_index = 0U](
             AnyStream *stream,
             std::optional<int> n) mutable -> std::optional<size_t> {
    // Lazily fetch the constraint list on the first invocation only.
    if (!constraints_info_) {
      constraints_info_.emplace(dba_->ListAllConstraints());
    }
    const auto &existence = constraints_info_->existence;
    size_t streamed = 0;
    for (; next_index < existence.size() && (!n || streamed < *n);
         ++next_index, ++streamed) {
      const auto &[label, property] = existence[next_index];
      std::ostringstream os;
      DumpExistenceConstraint(&os, dba_, label, property);
      stream->Result({TypedValue(os.str())});
    }
    // Report the streamed count only when every constraint has been dumped;
    // std::nullopt means the budget n ran out first.
    if (next_index == existence.size()) {
      return streamed;
    }
    return std::nullopt;
  };
}
PullPlanDump::PullChunk PullPlanDump::CreateUniqueConstraintsPullChunk() {
  // Streams one statement per unique constraint; resumes from
  // `next_index` across successive pulls.
  return [this, next_index = 0U](
             AnyStream *stream,
             std::optional<int> n) mutable -> std::optional<size_t> {
    // Lazily fetch the constraint list on the first invocation only.
    if (!constraints_info_) {
      constraints_info_.emplace(dba_->ListAllConstraints());
    }
    const auto &unique = constraints_info_->unique;
    size_t streamed = 0;
    for (; next_index < unique.size() && (!n || streamed < *n);
         ++next_index, ++streamed) {
      const auto &[label, properties] = unique[next_index];
      std::ostringstream os;
      DumpUniqueConstraint(&os, dba_, label, properties);
      stream->Result({TypedValue(os.str())});
    }
    // Report the streamed count only when every constraint has been dumped;
    // std::nullopt means the budget n ran out first.
    if (next_index == unique.size()) {
      return streamed;
    }
    return std::nullopt;
  };
}
PullPlanDump::PullChunk PullPlanDump::CreateInternalIndexPullChunk() {
  // Emits the helper index statement used while recreating edges, but only
  // when the database contains at least one vertex; otherwise nothing is
  // streamed.
  return [this](AnyStream *stream,
                std::optional<int>) mutable -> std::optional<size_t> {
    if (vertices_iterable_.begin() == vertices_iterable_.end()) {
      // Empty database: no helper index is needed.
      return 0;
    }
    std::ostringstream os;
    os << "CREATE INDEX ON :" << kInternalVertexLabel << "("
       << kInternalPropertyId << ");";
    stream->Result({TypedValue(os.str())});
    // Remember that the index was emitted so the cleanup chunks drop it.
    internal_index_created_ = true;
    return 1;
  };
}
// Streams one CREATE statement per vertex, resuming from the iterator saved
// in the closure across successive pulls.
PullPlanDump::PullChunk PullPlanDump::CreateVertexPullChunk() {
  return [this,
          maybe_current_iter = std::optional<VertexAccessorIterableIterator>{}](
             AnyStream *stream,
             std::optional<int> n) mutable -> std::optional<size_t> {
    // Delay the call of begin() function
    // If multiple begins are called before an iteration,
    // one iteration will make the rest of iterators be in undefined
    // states.
    if (!maybe_current_iter) {
      maybe_current_iter.emplace(vertices_iterable_.begin());
    }
    auto &current_iter{*maybe_current_iter};
    size_t local_counter = 0;
    // Stream vertices until either the iterator is exhausted or the
    // per-pull budget n is used up.
    while (current_iter != vertices_iterable_.end() &&
           (!n || local_counter < *n)) {
      std::ostringstream os;
      DumpVertex(&os, dba_, *current_iter);
      stream->Result({TypedValue(os.str())});
      ++local_counter;
      ++current_iter;
    }
    // Return the streamed count only when all vertices were dumped;
    // std::nullopt signals that the budget ran out first.
    if (current_iter == vertices_iterable_.end()) {
      return local_counter;
    }
    return std::nullopt;
  };
}
// Streams one CREATE statement per edge by walking every vertex's outgoing
// edges. Both the vertex iterator and the inner edge iterator (plus the
// iterable that keeps the edge iterator valid) are saved in the closure so
// streaming can resume mid-vertex on the next pull.
PullPlanDump::PullChunk PullPlanDump::CreateEdgePullChunk() {
  return
      [this,
       maybe_current_vertex_iter =
           std::optional<VertexAccessorIterableIterator>{},
       // we need to save the iterable which contains list of accessor so
       // our saved iterator is valid in the next run
       maybe_edge_iterable = std::shared_ptr<EdgeAccessorIterable>{nullptr},
       maybe_current_edge_iter = std::optional<EdgeAccessorIterableIterator>{}](
          AnyStream *stream,
          std::optional<int> n) mutable -> std::optional<size_t> {
        // Delay the call of begin() function
        // If multiple begins are called before an iteration,
        // one iteration will make the rest of iterators be in undefined
        // states.
        if (!maybe_current_vertex_iter) {
          maybe_current_vertex_iter.emplace(vertices_iterable_.begin());
        }
        auto &current_vertex_iter{*maybe_current_vertex_iter};
        size_t local_counter = 0U;
        for (; current_vertex_iter != vertices_iterable_.end() &&
               (!n || local_counter < *n);
             ++current_vertex_iter) {
          const auto &vertex = *current_vertex_iter;
          // If we have a saved iterable from a previous pull
          // we need to use the same iterable
          if (!maybe_edge_iterable) {
            maybe_edge_iterable = std::make_shared<EdgeAccessorIterable>(
                vertex.OutEdges(storage::View::OLD));
          }
          auto &maybe_edges = *maybe_edge_iterable;
          CHECK(maybe_edges.HasValue()) << "Invalid database state!";
          // Resume from the saved edge iterator if one exists, otherwise
          // start at the beginning of this vertex's out-edges.
          auto current_edge_iter = maybe_current_edge_iter
                                       ? *maybe_current_edge_iter
                                       : maybe_edges->begin();
          for (; current_edge_iter != maybe_edges->end() &&
                 (!n || local_counter < *n);
               ++current_edge_iter) {
            std::ostringstream os;
            DumpEdge(&os, dba_, *current_edge_iter);
            stream->Result({TypedValue(os.str())});
            ++local_counter;
          }
          // Budget ran out mid-vertex: remember where we stopped and signal
          // an unfinished chunk with std::nullopt.
          if (current_edge_iter != maybe_edges->end()) {
            maybe_current_edge_iter.emplace(current_edge_iter);
            return std::nullopt;
          }
          // This vertex is done; reset the saved edge state so the next
          // vertex builds a fresh iterable.
          maybe_current_edge_iter = std::nullopt;
          maybe_edge_iterable = nullptr;
        }
        if (current_vertex_iter == vertices_iterable_.end()) {
          return local_counter;
        }
        return std::nullopt;
      };
}
// Emits a DROP INDEX statement for the helper index, but only if that index
// was actually created earlier in this dump (internal_index_created_).
PullPlanDump::PullChunk PullPlanDump::CreateDropInternalIndexPullChunk() {
  return [this](AnyStream *stream, std::optional<int>) {
    if (internal_index_created_) {
      std::ostringstream os;
      os << "DROP INDEX ON :" << kInternalVertexLabel << "("
         << kInternalPropertyId << ");";
      stream->Result({TypedValue(os.str())});
      return 1;
    }
    // Nothing to drop: stream nothing and report zero results.
    return 0;
  };
}
// Emits a MATCH ... REMOVE statement that strips the helper label and
// property from every vertex, but only if the helper index was created
// earlier in this dump (internal_index_created_).
PullPlanDump::PullChunk PullPlanDump::CreateInternalIndexCleanupPullChunk() {
  return [this](AnyStream *stream, std::optional<int>) {
    if (internal_index_created_) {
      std::ostringstream os;
      os << "MATCH (u) REMOVE u:" << kInternalVertexLabel << ", u."
         << kInternalPropertyId << ";";
      stream->Result({TypedValue(os.str())});
      return 1;
    }
    // Nothing to clean up: stream nothing and report zero results.
    return 0;
  };
}
// Dumps the whole database in one shot: an empty `n` (std::nullopt) tells
// Pull to stream every chunk without a row limit.
void DumpDatabaseToCypherQueries(query::DbAccessor *dba, AnyStream *stream) {
  PullPlanDump dump(dba);
  dump.Pull(stream, std::nullopt);
}
} // namespace query

View File

@ -4,9 +4,57 @@
#include "query/db_accessor.hpp"
#include "query/stream.hpp"
#include "storage/v2/storage.hpp"
namespace query {
void DumpDatabaseToCypherQueries(query::DbAccessor *dba, AnyStream *stream);
// Lazily streams an entire database dump as Cypher statements. The dump is
// split into self-contained "chunks" (indices, constraints, vertices,
// edges, cleanup) that are pulled in order across multiple Pull calls.
struct PullPlanDump {
  explicit PullPlanDump(query::DbAccessor *dba);

  /// Pull the dump results lazily
  /// @return true if all results were returned, false otherwise
  bool Pull(AnyStream *stream, std::optional<int> n);

 private:
  // Accessor used to read indices, constraints, vertices and edges.
  query::DbAccessor *dba_ = nullptr;

  // Cached lazily on first use by the index/constraint chunks.
  std::optional<storage::IndicesInfo> indices_info_ = std::nullopt;
  std::optional<storage::ConstraintsInfo> constraints_info_ = std::nullopt;

  using VertexAccessorIterable =
      decltype(std::declval<query::DbAccessor>().Vertices(storage::View::OLD));
  using VertexAccessorIterableIterator =
      decltype(std::declval<VertexAccessorIterable>().begin());

  using EdgeAccessorIterable =
      decltype(std::declval<VertexAccessor>().OutEdges(storage::View::OLD));
  using EdgeAccessorIterableIterator =
      decltype(std::declval<EdgeAccessorIterable>().GetValue().begin());

  VertexAccessorIterable vertices_iterable_;
  // Set when the helper index chunk runs so the cleanup chunks know
  // whether to emit their statements.
  bool internal_index_created_ = false;

  // Index of the chunk currently being streamed.
  size_t current_chunk_index_ = 0;

  using PullChunk = std::function<std::optional<size_t>(AnyStream *stream,
                                                        std::optional<int> n)>;
  // We define every part of the dump query in a self contained function.
  // Each functions is responsible of keeping track of its execution status.
  // If a function did finish its execution, it should return number of results
  // it streamed so we know how many rows should be pulled from the next
  // function, otherwise std::nullopt is returned.
  std::vector<PullChunk> pull_chunks_;

  PullChunk CreateLabelIndicesPullChunk();
  PullChunk CreateLabelPropertyIndicesPullChunk();
  PullChunk CreateExistenceConstraintsPullChunk();
  PullChunk CreateUniqueConstraintsPullChunk();
  PullChunk CreateInternalIndexPullChunk();
  PullChunk CreateVertexPullChunk();
  PullChunk CreateEdgePullChunk();
  PullChunk CreateDropInternalIndexPullChunk();
  PullChunk CreateInternalIndexCleanupPullChunk();
};
} // namespace query

View File

@ -159,4 +159,12 @@ class StreamClauseInMulticommandTxException : public QueryException {
"Stream clause not allowed in multicommand transactions.") {}
};
// Thrown when a client supplies an argument value the server cannot accept
// (for example, an invalid value accompanying a query request).
class InvalidArgumentsException : public QueryException {
 public:
  // @param argument_name name of the offending argument
  // @param message human-readable reason the argument is invalid
  InvalidArgumentsException(const std::string &argument_name,
                            const std::string &message)
      : QueryException(fmt::format("Invalid arguments sent: {} - {}",
                                   argument_name, message)) {}
};
} // namespace query

View File

@ -5,6 +5,8 @@
#include <glog/logging.h>
#include "glue/communication.hpp"
#include "query/context.hpp"
#include "query/db_accessor.hpp"
#include "query/dump.hpp"
#include "query/exceptions.hpp"
#include "query/frontend/ast/cypher_main_visitor.hpp"
@ -15,9 +17,11 @@
#include "query/plan/planner.hpp"
#include "query/plan/profile.hpp"
#include "query/plan/vertex_count_cache.hpp"
#include "query/typed_value.hpp"
#include "utils/algorithm.hpp"
#include "utils/exceptions.hpp"
#include "utils/flag_validation.hpp"
#include "utils/memory.hpp"
#include "utils/string.hpp"
#include "utils/tsc.hpp"
@ -324,72 +328,155 @@ Interpreter::Interpreter(InterpreterContext *interpreter_context)
CHECK(interpreter_context_) << "Interpreter context must not be NULL";
}
ExecutionContext PullAllPlan(AnyStream *stream, const CachedPlan &plan,
const Parameters &parameters,
const std::vector<Symbol> &output_symbols,
bool is_profile_query,
std::map<std::string, TypedValue> *summary,
DbAccessor *dba,
InterpreterContext *interpreter_context,
utils::MonotonicBufferResource *execution_memory) {
auto cursor = plan.plan().MakeCursor(execution_memory);
Frame frame(plan.symbol_table().max_position(), execution_memory);
namespace {
// Struct for lazy pulling from a vector
struct PullPlanVector {
explicit PullPlanVector(std::vector<std::vector<TypedValue>> values)
: values_(std::move(values)) {}
// @return true if there are more unstreamed elements in vector,
// false otherwise.
bool Pull(AnyStream *stream, std::optional<int> n) {
int local_counter{0};
while (global_counter < values_.size() && (!n || local_counter < n)) {
stream->Result(values_[global_counter]);
++global_counter;
++local_counter;
}
return global_counter == values_.size();
}
private:
int global_counter{0};
std::vector<std::vector<TypedValue>> values_;
};
// Executes a cached query plan incrementally: each Pull streams at most n
// rows, preserving cursor/frame/context state between calls.
struct PullPlan {
  explicit PullPlan(std::shared_ptr<CachedPlan> plan,
                    const Parameters &parameters, bool is_profile_query,
                    DbAccessor *dba, InterpreterContext *interpreter_context,
                    utils::MonotonicBufferResource *execution_memory);

  // Streams at most n rows (all rows when n is std::nullopt) into `stream`.
  // Returns the finished ExecutionContext once the query is fully pulled,
  // or std::nullopt if more results remain for a future Pull.
  std::optional<ExecutionContext> Pull(
      AnyStream *stream, std::optional<int> n,
      const std::vector<Symbol> &output_symbols,
      std::map<std::string, TypedValue> *summary);

 private:
  std::shared_ptr<CachedPlan> plan_ = nullptr;
  plan::UniqueCursorPtr cursor_ = nullptr;
  Frame frame_;
  ExecutionContext ctx_;

  // As it's possible to query execution using multiple pulls
  // we need the keep track of the total execution time across
  // those pulls by accumulating the execution time.
  std::chrono::duration<double> execution_time_{0};

  // To pull the results from a query we call the `Pull` method on
  // the cursor which saves the results in a Frame.
  // Becuase we can't find out if there are some saved results in a frame,
  // and the cursor cannot deduce if the next pull will have a result,
  // we have to keep track of any unsent results from previous `PullPlan::Pull`
  // manually by using this flag.
  bool has_unsent_results_ = false;
};
// Builds the cursor and frame for the plan and fills the execution context
// (timestamp, parameters, name mappings, timers and shutdown/timeout hooks)
// that every subsequent Pull will reuse.
PullPlan::PullPlan(const std::shared_ptr<CachedPlan> plan,
                   const Parameters &parameters, const bool is_profile_query,
                   DbAccessor *dba, InterpreterContext *interpreter_context,
                   utils::MonotonicBufferResource *execution_memory)
    : plan_(plan),
      cursor_(plan->plan().MakeCursor(execution_memory)),
      frame_(plan->symbol_table().max_position(), execution_memory) {
  ctx_.db_accessor = dba;
  ctx_.symbol_table = plan->symbol_table();
  // Query-start timestamp in milliseconds since the epoch.
  ctx_.evaluation_context.timestamp =
      std::chrono::duration_cast<std::chrono::milliseconds>(
          std::chrono::system_clock::now().time_since_epoch())
          .count();
  ctx_.evaluation_context.parameters = parameters;
  // Resolve property/label names referenced by the AST into storage ids.
  ctx_.evaluation_context.properties =
      NamesToProperties(plan->ast_storage().properties_, dba);
  ctx_.evaluation_context.labels =
      NamesToLabels(plan->ast_storage().labels_, dba);
  ctx_.execution_tsc_timer =
      utils::TSCTimer(interpreter_context->tsc_frequency);
  ctx_.max_execution_time_sec = interpreter_context->execution_timeout_sec;
  ctx_.is_shutting_down = &interpreter_context->is_shutting_down;
  ctx_.is_profile_query = is_profile_query;
}
std::optional<ExecutionContext> PullPlan::Pull(
AnyStream *stream, std::optional<int> n,
const std::vector<Symbol> &output_symbols,
std::map<std::string, TypedValue> *summary) {
// Set up temporary memory for a single Pull. Initial memory comes from the
// stack. 256 KiB should fit on the stack and should be more than enough for a
// single `Pull`.
constexpr size_t stack_size = 256 * 1024;
char stack_data[stack_size];
ExecutionContext ctx;
ctx.db_accessor = dba;
ctx.symbol_table = plan.symbol_table();
ctx.evaluation_context.timestamp =
std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::system_clock::now().time_since_epoch())
.count();
ctx.evaluation_context.parameters = parameters;
ctx.evaluation_context.properties =
NamesToProperties(plan.ast_storage().properties_, dba);
ctx.evaluation_context.labels =
NamesToLabels(plan.ast_storage().labels_, dba);
ctx.execution_tsc_timer = utils::TSCTimer(interpreter_context->tsc_frequency);
ctx.max_execution_time_sec = interpreter_context->execution_timeout_sec;
ctx.is_shutting_down = &interpreter_context->is_shutting_down;
ctx.is_profile_query = is_profile_query;
utils::Timer timer;
while (true) {
// Returns true if a result was pulled.
const auto pull_result = [&]() -> bool {
utils::MonotonicBufferResource monotonic_memory(&stack_data[0], stack_size);
// TODO (mferencevic): Tune the parameters accordingly.
utils::PoolResource pool_memory(128, 1024, &monotonic_memory);
ctx.evaluation_context.memory = &pool_memory;
ctx_.evaluation_context.memory = &pool_memory;
if (!cursor->Pull(frame, ctx)) {
return cursor_->Pull(frame_, ctx_);
};
const auto stream_values = [&]() {
// TODO: The streamed values should also probably use the above memory.
std::vector<TypedValue> values;
values.reserve(output_symbols.size());
for (const auto &symbol : output_symbols) {
values.emplace_back(frame_[symbol]);
}
stream->Result(values);
};
// Get the execution time of all possible result pulls and streams.
utils::Timer timer;
int i = 0;
if (has_unsent_results_ && !output_symbols.empty()) {
// stream unsent results from previous pull
stream_values();
++i;
}
for (; !n || i < n; ++i) {
if (!pull_result()) {
break;
}
if (!output_symbols.empty()) {
// TODO: The streamed values should also probably use the above memory.
std::vector<TypedValue> values;
values.reserve(output_symbols.size());
for (const auto &symbol : output_symbols) {
values.emplace_back(frame[symbol]);
}
stream->Result(values);
stream_values();
}
}
auto execution_time = timer.Elapsed();
ctx.profile_execution_time = execution_time;
summary->insert_or_assign("plan_execution_time", execution_time.count());
cursor->Shutdown();
// If we finished because we streamed the requested n results,
// we try to pull the next result to see if there is more.
// If there is additional result, we leave the pulled result in the frame
// and set the flag to true.
has_unsent_results_ = i == n && pull_result();
return ctx;
execution_time_ += timer.Elapsed();
if (has_unsent_results_) {
return std::nullopt;
}
summary->insert_or_assign("plan_execution_time", execution_time_.count());
cursor_->Shutdown();
ctx_.profile_execution_time = execution_time_;
return ctx_;
}
} // namespace
/**
* Convert a parsed *Cypher* query's AST into a logical plan.
@ -468,7 +555,7 @@ PreparedQuery Interpreter::PrepareTransactionQuery(
try {
Commit();
} catch (const utils::BasicException &) {
AbortCommand();
AbortCommand(nullptr);
throw;
}
@ -489,10 +576,11 @@ PreparedQuery Interpreter::PrepareTransactionQuery(
LOG(FATAL) << "Should not get here -- unknown transaction query!";
}
return {{}, {}, [handler = std::move(handler)](AnyStream *) {
handler();
return QueryHandlerResult::NOTHING;
}};
return {
{}, {}, [handler = std::move(handler)](AnyStream *, std::optional<int>) {
handler();
return QueryHandlerResult::NOTHING;
}};
}
PreparedQuery PrepareCypherQuery(
@ -521,14 +609,19 @@ PreparedQuery PrepareCypherQuery(
.first);
}
auto pull_plan =
std::make_shared<PullPlan>(plan, parsed_query.parameters, false, dba,
interpreter_context, execution_memory);
return PreparedQuery{
std::move(header), std::move(parsed_query.required_privileges),
[plan = std::move(plan), parameters = std::move(parsed_query.parameters),
output_symbols = std::move(output_symbols), summary, dba,
interpreter_context, execution_memory](AnyStream *stream) {
PullAllPlan(stream, *plan, parameters, output_symbols, false, summary,
dba, interpreter_context, execution_memory);
return QueryHandlerResult::COMMIT;
[pull_plan = std::move(pull_plan),
output_symbols = std::move(output_symbols),
summary](AnyStream *stream,
std::optional<int> n) -> std::optional<QueryHandlerResult> {
if (pull_plan->Pull(stream, n, output_symbols, summary)) {
return QueryHandlerResult::COMMIT;
}
return std::nullopt;
}};
}
@ -537,7 +630,6 @@ PreparedQuery PrepareExplainQuery(
InterpreterContext *interpreter_context, DbAccessor *dba,
utils::MonotonicBufferResource *execution_memory) {
const std::string kExplainQueryStart = "explain ";
CHECK(
utils::StartsWith(utils::ToLowerCase(parsed_query.stripped_query.query()),
kExplainQueryStart))
@ -577,12 +669,14 @@ PreparedQuery PrepareExplainQuery(
return PreparedQuery{
{"QUERY PLAN"},
std::move(parsed_query.required_privileges),
[rows = std::move(printed_plan_rows)](AnyStream *stream) {
for (const auto &row : rows) {
stream->Result(row);
[pull_plan =
std::make_shared<PullPlanVector>(std::move(printed_plan_rows))](
AnyStream *stream,
std::optional<int> n) -> std::optional<QueryHandlerResult> {
if (pull_plan->Pull(stream, n)) {
return QueryHandlerResult::COMMIT;
}
return QueryHandlerResult::ABORT;
return std::nullopt;
}};
}
@ -643,33 +737,50 @@ PreparedQuery PrepareProfileQuery(
std::move(parsed_query.required_privileges),
[plan = std::move(cypher_query_plan),
parameters = std::move(parsed_inner_query.parameters), summary, dba,
interpreter_context, execution_memory](AnyStream *stream) {
interpreter_context, execution_memory,
// We want to execute the query we are profiling lazily, so we delay
// the construction of the corresponding context.
ctx = std::optional<ExecutionContext>{},
pull_plan = std::shared_ptr<PullPlanVector>(nullptr)](
AnyStream *stream,
std::optional<int> n) mutable -> std::optional<QueryHandlerResult> {
// No output symbols are given so that nothing is streamed.
auto ctx = PullAllPlan(stream, *plan, parameters, {}, true, summary,
dba, interpreter_context, execution_memory);
for (const auto &row :
ProfilingStatsToTable(ctx.stats, ctx.profile_execution_time)) {
stream->Result(row);
if (!ctx) {
ctx = PullPlan(plan, parameters, true, dba, interpreter_context,
execution_memory)
.Pull(stream, {}, {}, summary);
pull_plan = std::make_shared<PullPlanVector>(
ProfilingStatsToTable(ctx->stats, ctx->profile_execution_time));
}
summary->insert_or_assign(
"profile",
ProfilingStatsToJson(ctx.stats, ctx.profile_execution_time).dump());
CHECK(ctx) << "Failed to execute the query!";
return QueryHandlerResult::ABORT;
if (pull_plan->Pull(stream, n)) {
summary->insert_or_assign(
"profile",
ProfilingStatsToJson(ctx->stats, ctx->profile_execution_time)
.dump());
return QueryHandlerResult::ABORT;
}
return std::nullopt;
}};
}
PreparedQuery PrepareDumpQuery(
ParsedQuery parsed_query, std::map<std::string, TypedValue> *summary,
DbAccessor *dba, utils::MonotonicBufferResource *execution_memory) {
return PreparedQuery{{"QUERY"},
std::move(parsed_query.required_privileges),
[dba](AnyStream *stream) {
DumpDatabaseToCypherQueries(dba, stream);
return QueryHandlerResult::COMMIT;
}};
return PreparedQuery{
{"QUERY"},
std::move(parsed_query.required_privileges),
[pull_plan = std::make_shared<PullPlanDump>(dba)](
AnyStream *stream,
std::optional<int> n) -> std::optional<QueryHandlerResult> {
if (pull_plan->Pull(stream, n)) {
return QueryHandlerResult::COMMIT;
}
return std::nullopt;
}};
}
PreparedQuery PrepareIndexQuery(
@ -732,12 +843,13 @@ PreparedQuery PrepareIndexQuery(
}
}
return PreparedQuery{{},
std::move(parsed_query.required_privileges),
[handler = std::move(handler)](AnyStream *stream) {
handler();
return QueryHandlerResult::NOTHING;
}};
return PreparedQuery{
{},
std::move(parsed_query.required_privileges),
[handler = std::move(handler)](AnyStream *stream, std::optional<int>) {
handler();
return QueryHandlerResult::NOTHING;
}};
}
PreparedQuery PrepareAuthQuery(
@ -767,16 +879,20 @@ PreparedQuery PrepareAuthQuery(
[fn = callback.fn](Frame *, ExecutionContext *) { return fn(); }),
0.0, AstStorage{}, symbol_table));
auto pull_plan =
std::make_shared<PullPlan>(plan, parsed_query.parameters, false, dba,
interpreter_context, execution_memory);
return PreparedQuery{
callback.header, std::move(parsed_query.required_privileges),
[callback = std::move(callback), plan = std::move(plan),
parameters = std::move(parsed_query.parameters),
output_symbols = std::move(output_symbols), summary, dba,
interpreter_context, execution_memory](AnyStream *stream) {
PullAllPlan(stream, *plan, parameters, output_symbols, false, summary,
dba, interpreter_context, execution_memory);
return callback.should_abort_query ? QueryHandlerResult::ABORT
: QueryHandlerResult::COMMIT;
[pull_plan = std::move(pull_plan), callback = std::move(callback),
output_symbols = std::move(output_symbols),
summary](AnyStream *stream,
std::optional<int> n) -> std::optional<QueryHandlerResult> {
if (pull_plan->Pull(stream, n, output_symbols, summary)) {
return callback.should_abort_query ? QueryHandlerResult::ABORT
: QueryHandlerResult::COMMIT;
}
return std::nullopt;
}};
}
@ -859,17 +975,23 @@ PreparedQuery PrepareInfoQuery(
break;
}
return PreparedQuery{std::move(header),
std::move(parsed_query.required_privileges),
[handler = std::move(handler)](AnyStream *stream) {
auto [results, action] = handler();
return PreparedQuery{
std::move(header), std::move(parsed_query.required_privileges),
[handler = std::move(handler), action = QueryHandlerResult::NOTHING,
pull_plan = std::shared_ptr<PullPlanVector>(nullptr)](
AnyStream *stream,
std::optional<int> n) mutable -> std::optional<QueryHandlerResult> {
if (!pull_plan) {
auto [results, action_on_complete] = handler();
action = action_on_complete;
pull_plan = std::make_shared<PullPlanVector>(std::move(results));
}
for (const auto &result : results) {
stream->Result(result);
}
return action;
}};
if (pull_plan->Pull(stream, n)) {
return action;
}
return std::nullopt;
}};
}
PreparedQuery PrepareConstraintQuery(
@ -911,9 +1033,8 @@ PreparedQuery PrepareConstraintQuery(
auto label_name =
interpreter_context->db->LabelToName(violation.label);
CHECK(violation.properties.size() == 1U);
auto property_name =
interpreter_context->db->PropertyToName(
*violation.properties.begin());
auto property_name = interpreter_context->db->PropertyToName(
*violation.properties.begin());
throw QueryRuntimeException(
"Unable to create existence constraint :{}({}), because an "
"existing node violates it.",
@ -1024,29 +1145,55 @@ PreparedQuery PrepareConstraintQuery(
} break;
}
return PreparedQuery{{},
std::move(parsed_query.required_privileges),
[handler = std::move(handler)](AnyStream *stream) {
handler();
return QueryHandlerResult::COMMIT;
}};
return PreparedQuery{
{},
std::move(parsed_query.required_privileges),
[handler = std::move(handler)](AnyStream *stream, std::optional<int> n) {
handler();
return QueryHandlerResult::COMMIT;
}};
}
std::pair<std::vector<std::string>, std::vector<query::AuthQuery::Privilege>>
Interpreter::Prepare(
// Starts an explicit transaction by executing the "BEGIN" control query.
// Transaction control queries stream nothing, so no stream or row limit
// is supplied to the handler.
void Interpreter::BeginTransaction() {
  PrepareTransactionQuery("BEGIN").query_handler(nullptr, {});
}
// Commits the current explicit transaction by executing the "COMMIT"
// control query; the handler needs no stream or row limit.
void Interpreter::CommitTransaction() {
  PrepareTransactionQuery("COMMIT").query_handler(nullptr, {});
}
// Rolls back the current explicit transaction by executing the "ROLLBACK"
// control query; the handler needs no stream or row limit.
void Interpreter::RollbackTransaction() {
  PrepareTransactionQuery("ROLLBACK").query_handler(nullptr, {});
}
Interpreter::PrepareResult Interpreter::Prepare(
const std::string &query_string,
const std::map<std::string, storage::PropertyValue> &params) {
// Clear the last prepared query.
prepared_query_ = std::nullopt;
execution_memory_.Release();
if (!in_explicit_transaction_) {
// TODO(antonio2368): Should this throw?
CHECK(!ActiveQueryExecutions()) << "Only one active execution allowed "
"while not in explicit transaction!";
query_executions_.clear();
}
query_executions_.emplace_back(std::make_unique<QueryExecution>());
auto &query_execution = query_executions_.back();
std::optional<int> qid = in_explicit_transaction_
? static_cast<int>(query_executions_.size() - 1)
: std::optional<int>{};
// Handle transaction control queries.
auto query_upper = utils::Trim(utils::ToUpperCase(query_string));
if (query_upper == "BEGIN" || query_upper == "COMMIT" ||
query_upper == "ROLLBACK") {
prepared_query_ = PrepareTransactionQuery(query_upper);
return {prepared_query_->header, prepared_query_->privileges};
query_execution->prepared_query.emplace(
PrepareTransactionQuery(query_upper));
return {query_execution->prepared_query->header,
query_execution->prepared_query->privileges, qid};
}
// All queries other than transaction control queries advance the command in
@ -1057,28 +1204,26 @@ Interpreter::Prepare(
// If we're not in an explicit transaction block and we have an open
// transaction, abort it since we're about to prepare a new query.
else if (db_accessor_) {
AbortCommand();
AbortCommand(&query_execution);
}
try {
summary_ = {};
// TODO: Set summary['type'] based on transaction metadata. The type can't
// be determined based only on the toplevel logical operator -- for example
// `MATCH DELETE RETURN`, which is a write query, will have `Produce` as its
// toplevel operator). For now we always set "rw" because something must be
// set, but it doesn't have to be correct (for Bolt clients).
summary_["type"] = "rw";
query_execution->summary["type"] = "rw";
// Set a default cost estimate of 0. Individual queries can overwrite this
// field with an improved estimate.
summary_["cost_estimate"] = 0.0;
query_execution->summary["cost_estimate"] = 0.0;
utils::Timer parsing_timer;
ParsedQuery parsed_query =
ParseQuery(query_string, params, &interpreter_context_->ast_cache,
&interpreter_context_->antlr_lock);
summary_["parsing_time"] = parsing_timer.Elapsed().count();
query_execution->summary["parsing_time"] = parsing_timer.Elapsed().count();
// Some queries require an active transaction in order to be prepared.
if (!in_explicit_transaction_ &&
@ -1094,54 +1239,61 @@ Interpreter::Prepare(
PreparedQuery prepared_query;
if (utils::Downcast<CypherQuery>(parsed_query.query)) {
prepared_query = PrepareCypherQuery(
std::move(parsed_query), &summary_, interpreter_context_,
&*execution_db_accessor_, &execution_memory_);
prepared_query =
PrepareCypherQuery(std::move(parsed_query), &query_execution->summary,
interpreter_context_, &*execution_db_accessor_,
&query_execution->execution_memory);
} else if (utils::Downcast<ExplainQuery>(parsed_query.query)) {
prepared_query = PrepareExplainQuery(
std::move(parsed_query), &summary_, interpreter_context_,
&*execution_db_accessor_, &execution_memory_);
std::move(parsed_query), &query_execution->summary,
interpreter_context_, &*execution_db_accessor_,
&query_execution->execution_memory);
} else if (utils::Downcast<ProfileQuery>(parsed_query.query)) {
prepared_query = PrepareProfileQuery(
std::move(parsed_query), in_explicit_transaction_, &summary_,
interpreter_context_, &*execution_db_accessor_, &execution_memory_);
std::move(parsed_query), in_explicit_transaction_,
&query_execution->summary, interpreter_context_,
&*execution_db_accessor_, &query_execution->execution_memory);
} else if (utils::Downcast<DumpQuery>(parsed_query.query)) {
prepared_query =
PrepareDumpQuery(std::move(parsed_query), &summary_,
&*execution_db_accessor_, &execution_memory_);
prepared_query = PrepareDumpQuery(
std::move(parsed_query), &query_execution->summary,
&*execution_db_accessor_, &query_execution->execution_memory);
} else if (utils::Downcast<IndexQuery>(parsed_query.query)) {
prepared_query = PrepareIndexQuery(
std::move(parsed_query), in_explicit_transaction_, &summary_,
interpreter_context_, &execution_memory_);
prepared_query =
PrepareIndexQuery(std::move(parsed_query), in_explicit_transaction_,
&query_execution->summary, interpreter_context_,
&query_execution->execution_memory);
} else if (utils::Downcast<AuthQuery>(parsed_query.query)) {
prepared_query = PrepareAuthQuery(
std::move(parsed_query), in_explicit_transaction_, &summary_,
interpreter_context_, &*execution_db_accessor_, &execution_memory_);
std::move(parsed_query), in_explicit_transaction_,
&query_execution->summary, interpreter_context_,
&*execution_db_accessor_, &query_execution->execution_memory);
} else if (utils::Downcast<InfoQuery>(parsed_query.query)) {
prepared_query = PrepareInfoQuery(
std::move(parsed_query), in_explicit_transaction_, &summary_,
interpreter_context_, interpreter_context_->db, &execution_memory_);
std::move(parsed_query), in_explicit_transaction_,
&query_execution->summary, interpreter_context_,
interpreter_context_->db, &query_execution->execution_memory);
} else if (utils::Downcast<ConstraintQuery>(parsed_query.query)) {
prepared_query = PrepareConstraintQuery(
std::move(parsed_query), in_explicit_transaction_, &summary_,
interpreter_context_, &execution_memory_);
std::move(parsed_query), in_explicit_transaction_,
&query_execution->summary, interpreter_context_,
&query_execution->execution_memory);
} else {
LOG(FATAL) << "Should not get here -- unknown query type!";
}
summary_["planning_time"] = planning_timer.Elapsed().count();
prepared_query_ = std::move(prepared_query);
query_execution->summary["planning_time"] =
planning_timer.Elapsed().count();
query_execution->prepared_query.emplace(std::move(prepared_query));
return {prepared_query_->header, prepared_query_->privileges};
return {query_execution->prepared_query->header,
query_execution->prepared_query->privileges, qid};
} catch (const utils::BasicException &) {
AbortCommand();
AbortCommand(&query_execution);
throw;
}
}
void Interpreter::Abort() {
prepared_query_ = std::nullopt;
execution_memory_.Release();
expect_rollback_ = false;
in_explicit_transaction_ = false;
if (!db_accessor_) return;
@ -1151,8 +1303,11 @@ void Interpreter::Abort() {
}
void Interpreter::Commit() {
prepared_query_ = std::nullopt;
execution_memory_.Release();
// It's possible that some queries did not finish because the user did
// not pull all of the results from the query.
// For now, we will not check if there are some unfinished queries.
// We should document clearly that all results should be pulled to complete
// a query.
if (!db_accessor_) return;
auto maybe_constraint_violation = db_accessor_->Commit();
if (maybe_constraint_violation.HasError()) {
@ -1194,15 +1349,15 @@ void Interpreter::Commit() {
}
void Interpreter::AdvanceCommand() {
prepared_query_ = std::nullopt;
execution_memory_.Release();
if (!db_accessor_) return;
db_accessor_->AdvanceCommand();
}
void Interpreter::AbortCommand() {
prepared_query_ = std::nullopt;
execution_memory_.Release();
void Interpreter::AbortCommand(
std::unique_ptr<QueryExecution> *query_execution) {
if (query_execution) {
query_execution->reset(nullptr);
}
if (in_explicit_transaction_) {
expect_rollback_ = true;
} else {

View File

@ -4,12 +4,14 @@
#include "query/context.hpp"
#include "query/db_accessor.hpp"
#include "query/exceptions.hpp"
#include "query/frontend/ast/ast.hpp"
#include "query/frontend/ast/cypher_main_visitor.hpp"
#include "query/frontend/stripped.hpp"
#include "query/interpret/frame.hpp"
#include "query/plan/operator.hpp"
#include "query/stream.hpp"
#include "query/typed_value.hpp"
#include "utils/memory.hpp"
#include "utils/skip_list.hpp"
#include "utils/spin_lock.hpp"
@ -102,7 +104,9 @@ enum class QueryHandlerResult { COMMIT, ABORT, NOTHING };
struct PreparedQuery {
std::vector<std::string> header;
std::vector<AuthQuery::Privilege> privileges;
std::function<QueryHandlerResult(AnyStream *stream)> query_handler;
std::function<std::optional<QueryHandlerResult>(AnyStream *stream,
std::optional<int> n)>
query_handler;
};
// TODO: Maybe this should move to query/plan/planner.
@ -182,8 +186,7 @@ struct PlanCacheEntry {
* been passed to an `Interpreter` instance.
*/
struct InterpreterContext {
explicit InterpreterContext(storage::Storage *db)
: db(db) {
explicit InterpreterContext(storage::Storage *db) : db(db) {
CHECK(db) << "Storage must not be NULL";
}
@ -227,18 +230,23 @@ class Interpreter final {
Interpreter &operator=(Interpreter &&) = delete;
~Interpreter() { Abort(); }
// Result returned by `Interpreter::Prepare`: the column headers the prepared
// query will produce, the privileges required to execute it, and the query id
// used to reference this prepared query in later `Pull` calls.
struct PrepareResult {
  std::vector<std::string> headers;
  std::vector<query::AuthQuery::Privilege> privileges;
  // Index of the prepared query inside the interpreter's execution list;
  // presumably only meaningful inside an explicit transaction — TODO confirm.
  std::optional<int> qid;
};
/**
* Prepare a query for execution.
*
* To prepare a query for execution means to preprocess the query and adjust
* the state of the `Interpreter` in such a way so that the next call to
* `PullAll` executes the query.
* Preparing a query means to preprocess the query and save it for
* future calls of `Pull`.
*
* @throw query::QueryException
*/
std::pair<std::vector<std::string>, std::vector<query::AuthQuery::Privilege>>
Prepare(const std::string &query,
const std::map<std::string, storage::PropertyValue> &params);
PrepareResult Prepare(
const std::string &query,
const std::map<std::string, storage::PropertyValue> &params);
/**
* Execute the last prepared query and stream *all* of the results into the
@ -257,7 +265,37 @@ class Interpreter final {
* @throw query::QueryException
*/
template <typename TStream>
std::map<std::string, TypedValue> PullAll(TStream *result_stream);
std::map<std::string, TypedValue> PullAll(TStream *result_stream) {
return Pull(result_stream);
}
/**
* Execute a prepared query and stream result into the given stream.
*
* TStream should be a type implementing the `Stream` concept, i.e. it should
* contain the member function `void Result(const std::vector<TypedValue> &)`.
* The provided vector argument is valid only for the duration of the call to
* `Result`. The stream should make an explicit copy if it wants to use it
* further.
*
* @param n If set, amount of rows to be pulled from result,
* otherwise all the rows are pulled.
* @param qid If set, id of the query from which the result should be pulled,
* otherwise the last query should be used.
*
* @throw utils::BasicException
* @throw query::QueryException
*/
template <typename TStream>
std::map<std::string, TypedValue> Pull(TStream *result_stream,
std::optional<int> n = {},
std::optional<int> qid = {});
void BeginTransaction();
void CommitTransaction();
void RollbackTransaction();
/**
* Abort the current multicommand transaction.
@ -265,61 +303,141 @@ class Interpreter final {
void Abort();
private:
// State of a single prepared query: the query itself, the monotonic memory
// arena its execution allocates from, and the summary accumulated while it
// runs. Move-only, because the prepared query holds pointers into this
// object's execution memory.
struct QueryExecution {
  std::optional<PreparedQuery> prepared_query;
  utils::MonotonicBufferResource execution_memory{kExecutionMemoryBlockSize};
  std::map<std::string, TypedValue> summary;

  explicit QueryExecution() = default;
  QueryExecution(const QueryExecution &) = delete;
  QueryExecution(QueryExecution &&) = default;
  QueryExecution &operator=(const QueryExecution &) = delete;
  QueryExecution &operator=(QueryExecution &&) = default;

  ~QueryExecution() {
    // We should always release the execution memory AFTER we
    // destroy the prepared query which is using that instance
    // of execution memory.
    prepared_query.reset();
    execution_memory.Release();
  }
};
// Interpreter supports multiple prepared queries at the same time.
// The client can reference a specific query for pull using an arbitrary qid
// which is in our case the index of the query in the vector.
// To simplify the handling of the qid we avoid modifying the vector if it
// affects the position of the currently running queries in any way.
// For example, we cannot delete the prepared query from the vector because
// every prepared query after the deleted one will be moved by one place
// making their qid not equal to the their index inside the vector.
// To avoid this, we use unique_ptr with which we manualy control construction
// and deletion of a single query execution, i.e. when a query finishes,
// we reset the corresponding unique_ptr.
std::vector<std::unique_ptr<QueryExecution>> query_executions_;
InterpreterContext *interpreter_context_;
std::optional<PreparedQuery> prepared_query_;
std::map<std::string, TypedValue> summary_;
std::optional<storage::Storage::Accessor> db_accessor_;
std::optional<DbAccessor> execution_db_accessor_;
bool in_explicit_transaction_{false};
bool expect_rollback_{false};
utils::MonotonicBufferResource execution_memory_{kExecutionMemoryBlockSize};
PreparedQuery PrepareTransactionQuery(std::string_view query_upper);
void Commit();
void AdvanceCommand();
void AbortCommand();
void AbortCommand(std::unique_ptr<QueryExecution> *query_execution);
// Number of prepared queries that have not finished executing yet, i.e.
// slots in `query_executions_` that still hold a live `prepared_query`.
// Finished queries leave a null unique_ptr (or a reset prepared_query)
// behind so that qids of the remaining queries stay stable.
size_t ActiveQueryExecutions() {
  return std::count_if(query_executions_.begin(), query_executions_.end(),
                       [](const auto &execution) {
                         return execution && execution->prepared_query;
                       });
}
};
template <typename TStream>
std::map<std::string, TypedValue> Interpreter::PullAll(TStream *result_stream) {
CHECK(prepared_query_) << "Trying to call PullAll without a prepared query";
std::map<std::string, TypedValue> Interpreter::Pull(TStream *result_stream,
std::optional<int> n,
std::optional<int> qid) {
CHECK(in_explicit_transaction_ || !qid)
<< "qid can be only used in explicit transaction!";
const int qid_value =
qid ? *qid : static_cast<int>(query_executions_.size() - 1);
if (qid_value < 0 || qid_value >= query_executions_.size()) {
throw InvalidArgumentsException("qid",
"Query with specified ID does not exist!");
}
if (n && n < 0) {
throw InvalidArgumentsException("n",
"Cannot fetch negative number of results!");
}
auto &query_execution = query_executions_[qid_value];
CHECK(query_execution && query_execution->prepared_query)
<< "Query already finished executing!";
// Each prepared query has its own summary so we need to somehow preserve
// it after it finishes executing because it gets destroyed alongside
// the prepared query and its execution memory.
std::optional<std::map<std::string, TypedValue>> maybe_summary;
try {
// Wrap the (statically polymorphic) stream type into a common type which
// the handler knows.
AnyStream stream{result_stream, &execution_memory_};
QueryHandlerResult res = prepared_query_->query_handler(&stream);
// Erase the prepared query in order to enforce that every call to `PullAll`
// must be preceded by a call to `Prepare`.
prepared_query_ = std::nullopt;
AnyStream stream{result_stream, &query_execution->execution_memory};
const auto maybe_res =
query_execution->prepared_query->query_handler(&stream, n);
// Stream is using execution memory of the query_execution which
// can be deleted after its execution so the stream should be cleared
// first.
stream.~AnyStream();
if (!in_explicit_transaction_) {
switch (res) {
case QueryHandlerResult::COMMIT:
Commit();
break;
case QueryHandlerResult::ABORT:
Abort();
break;
case QueryHandlerResult::NOTHING:
// The only cases in which we have nothing to do are those where we're
// either in an explicit transaction or the query is such that a
// transaction wasn't started on a call to `Prepare()`.
CHECK(in_explicit_transaction_ || !db_accessor_);
break;
// If the query finished executing, we have received a value which tells
// us what to do after.
if (maybe_res) {
// Save its summary
maybe_summary.emplace(std::move(query_execution->summary));
if (!in_explicit_transaction_) {
switch (*maybe_res) {
case QueryHandlerResult::COMMIT:
Commit();
break;
case QueryHandlerResult::ABORT:
Abort();
break;
case QueryHandlerResult::NOTHING:
// The only cases in which we have nothing to do are those where
// we're either in an explicit transaction or the query is such that
// a transaction wasn't started on a call to `Prepare()`.
CHECK(in_explicit_transaction_ || !db_accessor_);
break;
}
// As the transaction is done we can clear all the executions
query_executions_.clear();
} else {
// We can only clear this execution as some of the queries
// in the transaction can be in unfinished state
query_execution.reset(nullptr);
}
}
} catch (const ExplicitTransactionUsageException &) {
// Just let the exception propagate for error reporting purposes, but don't
// abort the current command.
query_execution.reset(nullptr);
throw;
} catch (const utils::BasicException &) {
AbortCommand();
AbortCommand(&query_execution);
throw;
}
return summary_;
}
if (maybe_summary) {
// return the execution summary
maybe_summary->insert_or_assign("has_more", false);
return std::move(*maybe_summary);
}
// don't return the execution summary as it's not finished
return {{"has_more", TypedValue(true)}};
}
} // namespace query

View File

@ -1,15 +1,18 @@
# csharp driver
csharp/*.dll
csharp/*.exe
csharp/driver.*
csharp/**/*.dll
csharp/**/*.exe
csharp/**/driver.*
csharp/**/bin
csharp/**/obj
csharp/**/build
# java driver
java/*.class
java/*.jar
java/**/*.class
java/**/*.jar
# javascript driver
javascript/node_modules
javascript/package-lock.json
javascript/**/node_modules
javascript/**/package-lock.json
# python driver
python/ve3

View File

@ -0,0 +1,12 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>netcoreapp2.1</TargetFramework>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Neo4j.Driver.Simple" Version="4.1.1" />
</ItemGroup>
</Project>

View File

@ -0,0 +1,20 @@
using System;
using System.Linq;
using Neo4j.Driver;
// Basic smoke test for the .NET Neo4j driver against a local Bolt server:
// clears the database, creates one Person node and prints its properties.
public class Basic {
  public static void Main(string[] args) {
    // Plaintext Bolt connection to the default port; no authentication.
    using(var driver = GraphDatabase.Driver("bolt://localhost:7687", AuthTokens.None, (ConfigBuilder builder) => builder.WithEncryptionLevel(EncryptionLevel.None)))
    using(var session = driver.Session())
    {
      // Start from an empty database, then create a single node.
      session.Run("MATCH (n) DETACH DELETE n").Consume();
      session.Run("CREATE (alice:Person {name: \"Alice\", age: 22})").Consume();
      // Read the node back and echo its properties for the test harness.
      var result = session.Run("MATCH (n) RETURN n").First();
      var alice = (INode) result["n"];
      Console.WriteLine(alice["name"]);
      Console.WriteLine(string.Join(", ", alice.Labels));
      Console.WriteLine(alice["age"]);
    }
    Console.WriteLine("All ok!");
  }
}

View File

@ -0,0 +1,80 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Neo4j.Driver;
// Exercises explicit transactions through the .NET driver: a failing
// transaction must roll back completely, a correct one must commit, and a
// deliberately expensive read must hit the server-side execution timeout.
public class Transactions
{
  public static void Main(string[] args) {
    using(var driver = GraphDatabase.Driver("bolt://localhost:7687", AuthTokens.None, (builder) => builder.WithEncryptionLevel(EncryptionLevel.None)))
    {
      ClearDatabase(driver);
      // Wrong query.
      try
      {
        using(var session = driver.Session())
        using(var tx = session.BeginTransaction())
        {
          CreatePerson(tx, "mirko");
          // Incorrectly start CREATE
          tx.Run("CREATE (").Consume();
          CreatePerson(tx, "slavko");
          tx.Commit();
        }
      }
      catch (ClientException)
      {
        // The syntax error aborts the whole transaction.
        Console.WriteLine("Rolled back transaction");
      }
      Trace.Assert(CountNodes(driver) == 0, "Expected transaction was rolled back.");
      // Correct query.
      using(var session = driver.Session())
      using(var tx = session.BeginTransaction())
      {
        CreatePerson(tx, "mirka");
        CreatePerson(tx, "slavka");
        tx.Commit();
      }
      Trace.Assert(CountNodes(driver) == 2, "Expected 2 created nodes.");
      ClearDatabase(driver);
      using(var session = driver.Session())
      {
        // Create a lot of nodes so that the next read takes a long time.
        session.Run("UNWIND range(1, 100000) AS i CREATE ()").Consume();
        try
        {
          // Cartesian product over six node sets — expected to exceed the
          // server's query execution timeout and raise TransientException.
          Console.WriteLine("Running a long read...");
          session.Run("MATCH (a), (b), (c), (d), (e), (f) RETURN COUNT(*) AS cnt").Consume();
        }
        catch (TransientException)
        {
          Console.WriteLine("Transaction timed out");
        }
      }
    }
    Console.WriteLine("All ok!");
  }

  // Creates one Person node inside the given transaction and prints its name.
  private static void CreatePerson(ITransaction tx, string name)
  {
    var parameters = new Dictionary<string, Object>{{"name", name}};
    var result = tx.Run("CREATE (person:Person {name: $name}) RETURN person", parameters);
    Console.WriteLine("Created: " + ((INode) result.First()["person"])["name"]);
  }

  // Deletes every node (and attached relationship) in the database.
  private static void ClearDatabase(IDriver driver)
  {
    using(var session = driver.Session())
      session.Run("MATCH (n) DETACH DELETE n").Consume();
  }

  // Returns the total number of nodes currently in the database.
  private static int CountNodes(IDriver driver)
  {
    using(var session = driver.Session())
    {
      var result = session.Run("MATCH (n) RETURN COUNT(*) AS cnt");
      return Convert.ToInt32(result.First()["cnt"]);
    }
  }
}

View File

@ -0,0 +1,12 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>netcoreapp2.1</TargetFramework>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Neo4j.Driver.Simple" Version="4.1.1" />
</ItemGroup>
</Project>

View File

@ -0,0 +1,22 @@
#!/bin/bash -e

# Builds and runs every C# driver test in this directory.
# Run from this script's own directory so relative paths resolve.
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
cd "$DIR"

# check if dotnet-sdk-2.1 is installed
for i in dotnet; do
    if ! which $i >/dev/null; then
        echo "Please install $i!"
        exit 1
    fi
done

# Each subdirectory is a self-contained test project: publish it as a
# self-contained linux-x64 binary and execute it.
for i in *; do
    if [ ! -d $i ]; then
        continue
    fi
    pushd $i
    dotnet publish -c release --self-contained --runtime linux-x64 --framework netcoreapp2.1 -o build/
    ./build/$i
    popd
done;

View File

@ -0,0 +1,67 @@
package main
import "github.com/neo4j/neo4j-go-driver/neo4j"
import "fmt"
import "log"
// handle_error logs the given driver error and terminates the process with a
// non-zero exit status. It never returns.
func handle_error(err error) {
	// log.Fatal does not interpret format verbs, so the original call printed
	// a literal "%s" followed by the error; Fatalf applies the verb properly.
	// (Also fixes the "occured" typo in the message.)
	log.Fatalf("Error occurred: %s", err)
}
// main is a smoke test for the Go Neo4j driver against a local Bolt server:
// it clears the database, creates one Person node and prints its properties.
// Any failure terminates the process via log.Fatalf/handle_error.
func main() {
	// Encryption must be disabled explicitly with the 4.0-style config hook.
	configForNeo4j40 := func(conf *neo4j.Config) { conf.Encrypted = false }
	driver, err := neo4j.NewDriver("bolt://localhost:7687", neo4j.BasicAuth("", "", ""), configForNeo4j40)
	if err != nil {
		// log.Fatal does not interpret format verbs; Fatalf is required here.
		log.Fatalf("An error occurred opening conn: %s", err)
	}
	defer driver.Close()
	sessionConfig := neo4j.SessionConfig{AccessMode: neo4j.AccessModeWrite}
	session, err := driver.NewSession(sessionConfig)
	if err != nil {
		log.Fatalf("An error occurred while creating a session: %s", err)
	}
	defer session.Close()
	// Start from an empty database.
	result, err := session.Run("MATCH (n) DETACH DELETE n", map[string]interface{}{})
	if err != nil {
		handle_error(err)
	}
	_, err = result.Consume()
	if err != nil {
		handle_error(err)
	}
	result, err = session.Run(`CREATE (alice:Person {name: "Alice", age: 22})`, map[string]interface{}{})
	if err != nil {
		handle_error(err)
	}
	_, err = result.Consume()
	if err != nil {
		handle_error(err)
	}
	// Read the node back and echo its properties for the test harness.
	result, err = session.Run("MATCH (n) RETURN n", map[string]interface{}{})
	if err != nil {
		handle_error(err)
	}
	if !result.Next() {
		log.Fatal("Missing result")
	}
	node_record, has_column := result.Record().Get("n")
	if !has_column {
		log.Fatal("Wrong result returned")
	}
	node_value := node_record.(neo4j.Node)
	fmt.Println(node_value.Props()["name"])
	fmt.Println(node_value.Labels())
	fmt.Println(node_value.Props()["age"])
	fmt.Println("All ok!")
}

View File

@ -0,0 +1,5 @@
module bolt-test
go 1.15
require github.com/neo4j/neo4j-go-driver v1.8.3

View File

@ -0,0 +1,29 @@
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/neo4j/neo4j-go-driver v1.8.3 h1:yfuo9YBAlezdIiogu92GwEir/81RD81dNwS5mY/wAIk=
github.com/neo4j/neo4j-go-driver v1.8.3/go.mod h1:ncO5VaFWh0Nrt+4KT4mOZboaczBZcLuHrG+/sUeP8gI=
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/ginkgo v1.12.0/go.mod h1:oUhWkIvk5aDxtKvDDuw8gItl8pKl42LzjC9KZE0HfGg=
github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY=
github.com/onsi/gomega v1.9.0/go.mod h1:Ho0h+IUsWyvy1OpqCwxlQ/21gkhVunqlU8fDGcoTdcA=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0=
rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA=

14
tests/drivers/go/v4/run.sh Executable file
View File

@ -0,0 +1,14 @@
#!/bin/bash -e

# Runs the Go driver tests.
# Change into this script's directory first so `go run basic.go` etc. resolve
# no matter where the caller invokes the script from — every sibling driver
# run.sh does the same; this one was missing it.
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
cd "$DIR"

# check if go is installed
for i in go; do
    if ! which $i >/dev/null; then
        echo "Please install $i!"
        exit 1
    fi
done

# Fetch the driver, then run each test program.
go get github.com/neo4j/neo4j-go-driver/neo4j
go run basic.go
go run transactions.go

View File

@ -0,0 +1,94 @@
package main
import "github.com/neo4j/neo4j-go-driver/neo4j"
import "log"
import "fmt"
// handle_error logs the given driver error and terminates the process with a
// non-zero exit status. It never returns.
func handle_error(err error) {
	// log.Fatal does not interpret format verbs, so the original call printed
	// a literal "%s" followed by the error; Fatalf applies the verb properly.
	// (Also fixes the "occured" typo in the message.)
	log.Fatalf("Error occurred: %s", err)
}
// create_person creates a single Person node with the given name inside the
// supplied transaction and returns the "name" property read back from the
// created node. Any driver error aborts the whole process.
func create_person(tx neo4j.Transaction, name string) interface{} {
	result, err := tx.Run("CREATE (a:Person {name: $name}) RETURN a", map[string]interface{}{
		"name": name,
	})
	if err != nil {
		handle_error(err)
	}
	// The CREATE ... RETURN must yield exactly one record.
	if !result.Next() {
		log.Fatal("Missing results!");
	}
	node, has_column := result.Record().Get("a")
	if !has_column {
		log.Fatal("Wrong results")
	}
	node_value := node.(neo4j.Node)
	return node_value.Props()["name"]
}
// main exercises explicit transactions through the Go driver: a failing
// transaction (syntax error mid-transaction), a correct one, and a
// deliberately expensive read that must hit the server-side query timeout.
func main() {
	// Encryption must be disabled explicitly with the 4.0-style config hook.
	configForNeo4j40 := func(conf *neo4j.Config) { conf.Encrypted = false }
	driver, err := neo4j.NewDriver("bolt://localhost:7687", neo4j.BasicAuth("", "", ""), configForNeo4j40)
	if err != nil {
		// log.Fatal does not interpret format verbs; Fatalf is required here.
		log.Fatalf("An error occurred opening conn: %s", err)
	}
	defer driver.Close()
	sessionConfig := neo4j.SessionConfig{AccessMode: neo4j.AccessModeWrite}
	session, err := driver.NewSession(sessionConfig)
	if err != nil {
		log.Fatalf("An error occurred while creating a session: %s", err)
	}
	defer session.Close()
	// Transaction containing a syntactically invalid query: the driver must
	// report the failure instead of committing.
	session.WriteTransaction(func(tx neo4j.Transaction) (interface{}, error) {
		fmt.Println(create_person(tx, "mirko"))
		result, err := tx.Run("CREATE (", map[string]interface{}{})
		if err != nil {
			handle_error(err)
		}
		_, err = result.Consume()
		if err == nil {
			log.Fatal("The query should have failed")
		} else {
			fmt.Println("The query failed as expected")
		}
		return nil, nil
	})
	// Well-formed transaction: both creates must succeed.
	session.WriteTransaction(func(tx neo4j.Transaction) (interface{}, error) {
		fmt.Println(create_person(tx, "mirko"))
		fmt.Println(create_person(tx, "slavko"))
		return nil, nil
	})
	// Create many nodes so the cartesian-product read below is very slow.
	result, err := session.Run("UNWIND range(1, 100000) AS x CREATE ()", map[string]interface{}{})
	if err != nil {
		handle_error(err)
	}
	_, err = result.Consume()
	if err != nil {
		handle_error(err)
	}
	// This read is expected to exceed the server's execution timeout.
	session.WriteTransaction(func(tx neo4j.Transaction) (interface{}, error) {
		result, err := tx.Run("MATCH (a), (b), (c), (d), (e), (f) RETURN COUNT(*) AS cnt", map[string]interface{}{})
		if err != nil {
			handle_error(err)
		}
		_, err = result.Consume()
		if err == nil {
			log.Fatal("The query should have timed out")
		} else {
			fmt.Println("The query timed out as expected")
		}
		return nil, nil
	})
	fmt.Println("All ok!")
}

View File

@ -0,0 +1,38 @@
import org.neo4j.driver.*;
import org.neo4j.driver.types.*;
import static org.neo4j.driver.Values.parameters;
import java.util.*;
/**
 * Basic smoke test for the Java Neo4j driver against a local Bolt server:
 * clears the database, creates one Person node, reads it back and verifies
 * its properties. Exits with status 1 on any failure.
 */
public class Basic {
    public static void main(String[] args) {
        // Plaintext Bolt connection; credentials are ignored by the test server.
        Config config = Config.builder().withoutEncryption().build();
        Driver driver = GraphDatabase.driver( "bolt://localhost:7687", AuthTokens.basic( "neo4j", "1234" ), config );
        try ( Session session = driver.session() ) {
            Result rs1 = session.run( "MATCH (n) DETACH DELETE n" );
            System.out.println( "Database cleared." );
            Result rs2 = session.run( "CREATE (alice: Person {name: 'Alice', age: 22})" );
            System.out.println( "Record created." );
            Result rs3 = session.run( "MATCH (n) RETURN n" );
            System.out.println( "Record matched." );
            // The database contains exactly the one node created above.
            List<Record> records = rs3.list();
            Record record = records.get( 0 );
            Node node = record.get( "n" ).asNode();
            if ( !node.get("name").asString().equals( "Alice" ) || node.get("age").asInt() != 22 ) {
                System.out.println( "Data doesn't match!" );
                System.exit( 1 );
            }
            System.out.println( "All ok!" );
        }
        catch ( Exception e ) {
            System.out.println( e );
            System.exit( 1 );
        }
        driver.close();
    }
}

View File

@ -0,0 +1,59 @@
/**
* Determines how long could be a query executed
* from Java driver.
*
* Performs binary search until the maximum possible
* query size has found.
*/
import java.util.*;
import org.neo4j.driver.*;
import org.neo4j.driver.types.*;
import static org.neo4j.driver.Values.parameters;
/**
 * Binary-searches for the longest query string the Java driver can execute
 * against the server; prints the discovered maximum length at the end.
 */
public class MaxQueryLength {
    public static void main(String[] args) {
        // init driver
        Config config = Config.builder().withoutEncryption().build();
        Driver driver = GraphDatabase.driver("bolt://localhost:7687",
                                             AuthTokens.basic( "", "" ),
                                             config);
        // init query
        int property_size = 0;
        int min_len = 1;
        int max_len = 100000;
        String query_template = "CREATE (n {name:\"%s\"})";
        int template_size = query_template.length() - 2; // because of %s
        // binary search: grow the property value while queries succeed,
        // shrink it when the server/driver rejects the query.
        while (true) {
            property_size = (max_len + min_len) / 2;
            try (Session session = driver.session()) {
                // Build a name property consisting of `property_size` 'a's.
                String property_value = new String(new char[property_size])
                                            .replace('\0', 'a');
                String query = String.format(query_template, property_value);
                session.run(query).consume();
                if (min_len == max_len || property_size + 1 > max_len) {
                    break;
                }
                min_len = property_size + 1;
            }
            catch (Exception e) {
                // Failure: the current size is too big, search below it.
                System.out.println(String.format(
                    "Query length: %d; Error: %s",
                    property_size + template_size, e));
                max_len = property_size - 1;
            }
        }
        // final result
        System.out.println(
            String.format("\nThe max length of a query executed from " +
                          "Java driver is: %s\n",
                          property_size + template_size));
        // cleanup
        driver.close();
    }
}

View File

@ -0,0 +1,80 @@
import org.neo4j.driver.AuthTokens;
import org.neo4j.driver.GraphDatabase;
import org.neo4j.driver.Session;
import org.neo4j.driver.Driver;
import org.neo4j.driver.Transaction;
import org.neo4j.driver.TransactionWork;
import org.neo4j.driver.Result;
import org.neo4j.driver.Config;
import org.neo4j.driver.exceptions.ClientException;
import org.neo4j.driver.exceptions.TransientException;
import java.util.concurrent.TimeUnit;
import static org.neo4j.driver.Values.parameters;
import java.util.*;
/**
 * Exercises explicit transactions through the Java driver: a transaction
 * containing a syntax error must fail with ClientException, a correct one
 * must commit, and a deliberately expensive read must raise
 * TransientException when it exceeds the server-side execution timeout.
 */
public class Transactions {
    // Creates one Person node inside the transaction and returns its name.
    public static String createPerson(Transaction tx, String name) {
        Result result = tx.run("CREATE (a:Person {name: $name}) RETURN a.name", parameters("name", name));
        return result.single().get(0).asString();
    }

    public static void main(String[] args) {
        // Retry time of 0 so the timed-out transaction is not retried.
        Config config = Config.builder().withoutEncryption().withMaxTransactionRetryTime(0, TimeUnit.SECONDS).build();
        Driver driver = GraphDatabase.driver( "bolt://localhost:7687", AuthTokens.basic( "neo4j", "1234" ), config );
        try ( Session session = driver.session() ) {
            // Transaction with a syntactically invalid query — must throw.
            try {
                session.writeTransaction(new TransactionWork<String>()
                {
                    @Override
                    public String execute(Transaction tx) {
                        createPerson(tx, "mirko");
                        Result result = tx.run("CREATE (");
                        return result.single().get(0).asString();
                    }
                });
            } catch (ClientException e) {
                System.out.println(e);
            }
            // Well-formed transaction: both creates must succeed.
            session.writeTransaction(new TransactionWork<String>()
            {
                @Override
                public String execute(Transaction tx) {
                    System.out.println(createPerson(tx, "mirko"));
                    System.out.println(createPerson(tx, "slavko"));
                    return "Done";
                }
            });
            System.out.println( "All ok!" );
            // Cartesian product over six node sets — expected to exceed the
            // server's query execution timeout.
            boolean timed_out = false;
            try {
                session.writeTransaction(new TransactionWork<String>()
                {
                    @Override
                    public String execute(Transaction tx) {
                        Result result = tx.run("MATCH (a), (b), (c), (d), (e), (f) RETURN COUNT(*) AS cnt");
                        return result.single().get(0).asString();
                    }
                });
            } catch (TransientException e) {
                timed_out = true;
            }
            if (timed_out) {
                System.out.println("The query timed out as was expected.");
            } else {
                throw new Exception("The query should have timed out, but it didn't!");
            }
        }
        catch ( Exception e ) {
            System.out.println( e );
            System.exit( 1 );
        }
        driver.close();
    }
}

32
tests/drivers/java/v4_1/run.sh Executable file
View File

@ -0,0 +1,32 @@
#!/bin/bash -e

# Compiles and runs the Java driver tests against a locally downloaded
# Neo4j Java driver jar. Run from this script's own directory.
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
cd "$DIR"

# Make sure a JDK (runtime + compiler) is available.
for i in java javac; do
    if ! which $i >/dev/null; then
        echo "Please install $i!"
        exit 1
    fi
done

DRIVER=neo4j-java-driver.jar
REACTIVE_STREAM_DEP=reactive-streams.jar

if [ ! -f $DRIVER ]; then
    # Driver downloaded from: https://repo1.maven.org/maven2/org/neo4j/driver/neo4j-java-driver/4.1.1/neo4j-java-driver-4.1.1.jar
    wget -nv https://repo1.maven.org/maven2/org/neo4j/driver/neo4j-java-driver/4.1.1/neo4j-java-driver-4.1.1.jar -O $DRIVER || exit 1
fi

# The 4.x Java driver needs the Reactive Streams API on the classpath.
if [ ! -f $REACTIVE_STREAM_DEP ]; then
    wget -nv https://repo1.maven.org/maven2/org/reactivestreams/reactive-streams/1.0.3/reactive-streams-1.0.3.jar -O $REACTIVE_STREAM_DEP || exit 1
fi

# Compile and run each test program.
javac -classpath .:$DRIVER:$REACTIVE_STREAM_DEP Basic.java
java -classpath .:$DRIVER:$REACTIVE_STREAM_DEP Basic
javac -classpath .:$DRIVER:$REACTIVE_STREAM_DEP MaxQueryLength.java
java -classpath .:$DRIVER:$REACTIVE_STREAM_DEP MaxQueryLength
javac -classpath .:$DRIVER:$REACTIVE_STREAM_DEP Transactions.java
java -classpath .:$DRIVER:$REACTIVE_STREAM_DEP Transactions

View File

@ -0,0 +1,36 @@
// Basic smoke test for the JavaScript Neo4j driver against a local Bolt
// server: clears the database, creates one Person node, reads it back and
// verifies the label and name. Exits with status 1 on any failure.
var neo4j = require('neo4j-driver');
var driver = neo4j.driver("bolt://localhost:7687",
                          neo4j.auth.basic("neo4j", "1234"),
                          { encrypted: 'ENCRYPTION_OFF' });
var session = driver.session();

// Tear everything down and signal failure to the harness.
function die() {
    session.close();
    driver.close();
    process.exit(1);
}

// Runs `query`, invoking `callback` with the result on success and
// aborting the process on any driver error.
function run_query(query, callback) {
    var run = session.run(query, {});
    run.then(callback).catch(function (error) {
        console.log(error);
        die();
    });
}

// The three steps are chained through callbacks because session.run is
// promise-based: clear -> create -> match & verify.
run_query("MATCH (n) DETACH DELETE n", function (result) {
    console.log("Database cleared.");
    run_query("CREATE (alice: Person {name: 'Alice', age: 22})", function (result) {
        console.log("Record created.");
        run_query("MATCH (n) RETURN n", function (result) {
            console.log("Record matched.");
            var alice = result.records[0].get("n");
            if(alice.labels[0] != "Person" || alice.properties["name"] != "Alice"){
                console.log("Data doesn't match!");
                die();
            }
            console.log("All ok!");
            driver.close();
        });
    });
});

View File

@ -0,0 +1,51 @@
// Determines how long could be a query executed
// from JavaScript driver.
//
// Performs binary search until the maximum possible
// query size has found.
// init driver
var neo4j = require('neo4j-driver');
var driver = neo4j.driver("bolt://localhost:7687",
neo4j.auth.basic("", ""),
{ encrypted: 'ENCRYPTION_OFF' });
// init state
var property_size = 0;
var min_len = 1;
var max_len = 1000000;
// hacking with JS and callbacks concept
function serial_execution() {
var next_size = [Math.floor((min_len + max_len) / 2)];
setInterval(function() {
if (next_size.length > 0) {
property_size = next_size.pop();
var query = "CREATE (n {name:\"" +
(new Array(property_size)).join("a")+ "\"})";
var session = driver.session();
session.run(query, {}).then(function (result) {
console.log("Success with the query length " + query.length);
if (min_len == max_len || property_size + 1 > max_len) {
console.log("\nThe max length of a query from JS driver is: " +
query.length + "\n");
session.close();
driver.close();
process.exit(0);
}
min_len = property_size + 1;
next_size.push(Math.floor((min_len + max_len) / 2));
}).catch(function (error) {
console.log("Failure with the query length " + query.length);
max_len = property_size - 1;
next_size.push(Math.floor((min_len + max_len) / 2));
}).then(function(){
session.close();
});
}
}, 100);
}
// execution
console.log("\nDetermine how long can be a query sent from JavaScript driver.");
serial_execution(); // I don't like JavaScript

View File

@ -0,0 +1,17 @@
#!/bin/bash -e
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
cd "$DIR"
if ! which node >/dev/null; then
echo "Please install nodejs!"
exit 1
fi
if [ ! -d node_modules ]; then
# Driver generated with: `npm install neo4j-driver`
npm install neo4j-driver
fi
node basic.js
node max_query_length.js

View File

@ -0,0 +1,25 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Smoke test for the Python Neo4j driver against a local Bolt server.

Clears the database, creates a single Person node and prints its
properties back, finishing with "All ok!" on success.
"""
from neo4j import GraphDatabase, basic_auth

driver = GraphDatabase.driver(
    "bolt://localhost:7687", auth=basic_auth("", ""), encrypted=False)
session = driver.session()

# Start from an empty database, then create the one node we will read back.
for setup_query in ('MATCH (n) DETACH DELETE n',
                    'CREATE (alice:Person {name: "Alice", age: 22})'):
    session.run(setup_query).consume()

# Fetch the single node and echo its properties for the test harness.
record = session.run('MATCH (n) RETURN n').single()
alice = record["n"]
for value in (alice['name'], set(alice.labels), alice['age']):
    print(value)

session.close()
driver.close()
print("All ok!")

View File

@ -0,0 +1,39 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from neo4j import GraphDatabase, basic_auth
driver = GraphDatabase.driver("bolt://localhost:7687",
auth=basic_auth("", ""),
encrypted=False)
query_template = 'CREATE (n {name:"%s"})'
template_size = len(query_template) - 2 # because of %s
min_len = 1
max_len = 1000000
# binary search because we have to find the maximum size (in number of chars)
# of a query that can be executed via driver
while True:
assert min_len > 0 and max_len > 0, \
"The lengths have to be positive values! If this happens something" \
" is terrible wrong with min & max lengths OR the database" \
" isn't available."
property_size = (max_len + min_len) // 2
try:
driver.session().run(query_template % ("a" * property_size)).consume()
if min_len == max_len or property_size + 1 > max_len:
break
min_len = property_size + 1
except Exception as e:
print("Query size %s is too big!" % (template_size + property_size))
max_len = property_size - 1
assert property_size == max_len, "max_len probably has to be increased!"
print("\nThe max length of a query from Python driver is: %s\n" %
(template_size + property_size))
# sessions are not closed bacause all sessions that are
# executed with wrong query size might be broken
driver.close()

View File

@ -0,0 +1,33 @@
#!/bin/bash -e
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
cd "$DIR"
# system check
if ! which virtualenv >/dev/null; then
echo "Please install virtualenv!"
exit 1
fi
# setup virtual environment
if [ ! -d "ve3" ]; then
# Driver downloaded from: https://pypi.org/project/neo4j-driver/1.5.3/
wget -nv https://files.pythonhosted.org/packages/cf/7d/32204b1c2d6f9f9d729bbf8273515c2b3ef42c0b723617d319f3e435f69e/neo4j-driver-4.1.1.tar.gz || exit 1
tar -xzf neo4j-driver-4.1.1.tar.gz || exit 1
mv neo4j-driver-4.1.1 neo4j-driver || exit 1
virtualenv -p python3 ve3 || exit 1
source ve3/bin/activate
cd neo4j-driver
python3 setup.py install || exit 1
cd ..
deactivate
rm -rf neo4j-driver neo4j-driver-4.1.1.tar.gz || exit 1
fi
# activate virtualenv
source ve3/bin/activate
# execute test
python3 basic.py || exit 1
python3 max_query_length.py || exit 1
python3 transactions.py || exit 1

View File

@ -0,0 +1,55 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from neo4j import GraphDatabase, basic_auth
from neo4j.exceptions import ClientError, TransientError
def tx_error(tx, name, name2):
    """Create a node, then run a syntactically invalid query.

    The invalid statement is expected to make the server fail the
    transaction, so the trailing CREATE should never take effect.
    """
    a = tx.run("CREATE (a:Person {name: $name}) RETURN a", name=name).value()
    print(a[0])
    # Invalid Cypher -- expected to raise an error on consume().
    tx.run("CREATE (").consume()
    a = tx.run("CREATE (a:Person {name: $name}) RETURN a", name=name2).value()
    print(a[0])
def tx_good(tx, name, name2):
    """Create two Person nodes inside the same transaction.

    Both statements are valid, so the transaction is expected to commit.
    """
    a = tx.run("CREATE (a:Person {name: $name}) RETURN a", name=name).value()
    print(a[0])
    a = tx.run("CREATE (a:Person {name: $name}) RETURN a", name=name2).value()
    print(a[0])
def tx_too_long(tx):
    """Start a query that is expected to exceed the execution timeout.

    NOTE(review): this helper appears unused -- the long-running query is
    executed directly through a session below; kept for reference.
    """
    tx.run("MATCH (a), (b), (c), (d), (e), (f) RETURN COUNT(*) AS cnt")
# Connect without encryption/auth and exercise transaction behaviour:
# a failing transaction, a successful one, and a server-side timeout.
with GraphDatabase.driver("bolt://localhost:7687", auth=basic_auth("", ""),
                          encrypted=False) as driver:

    def add_person(f, name, name2):
        # Run transaction function `f` inside a managed write transaction.
        with driver.session() as session:
            session.write_transaction(f, name, name2)

    # Wrong query.
    try:
        add_person(tx_error, "mirko", "slavko")
    except ClientError:
        pass

    # Correct query.
    add_person(tx_good, "mirka", "slavka")

    # Setup for next query.
    with driver.session() as session:
        session.run("UNWIND range(1, 100000) AS x CREATE ()").consume()

    # Query that will run for a very long time, transient error expected.
    timed_out = False
    try:
        with driver.session() as session:
            session.run(
                "MATCH (a), (b), (c), (d), (e), (f) "
                "RETURN COUNT(*) AS cnt").consume()
    except TransientError:
        timed_out = True

    if timed_out:
        print("The query timed out as was expected.")
    else:
        raise Exception("The query should have timed out, but it didn't!")

# NOTE(review): printed after the driver context exits; original indentation
# was lost in extraction -- confirm placement against the upstream script.
print("All ok!")

View File

@ -29,6 +29,8 @@ $binary_dir/memgraph \
--data-directory=$tmpdir \
--query-execution-timeout-sec=5 \
--bolt-session-inactivity-timeout=10 \
--bolt-cert-file="" \
--bolt-server-name-for-init="Neo4j/1.1" \
--min-log-level 1 &
pid=$!
wait_for_server 7687
@ -39,12 +41,19 @@ for i in *; do
if [ ! -d $i ]; then continue; fi
pushd $i
echo "Running: $i"
./run.sh
code_test=$?
if [ $code_test -ne 0 ]; then
echo "FAILED: $i"
break
fi
# run all versions
for v in *; do
if [ ! -d $v ]; then continue; fi
pushd $v
echo "Running version: $v"
./run.sh
code_test=$?
if [ $code_test -ne 0 ]; then
echo "FAILED: $i"
break
fi
popd
done;
echo
popd
done

View File

@ -60,7 +60,7 @@ class BoltSession final
using communication::bolt::Session<communication::InputStream,
communication::OutputStream>::TEncoder;
std::vector<std::string> Interpret(
std::pair<std::vector<std::string>, std::optional<int>> Interpret(
const std::string &query,
const std::map<std::string, communication::bolt::Value> &params)
override {
@ -70,7 +70,7 @@ class BoltSession final
auto ret = client_->Execute(query, params);
records_ = std::move(ret.records);
metadata_ = std::move(ret.metadata);
return ret.fields;
return {ret.fields, {}};
} catch (const communication::bolt::ClientQueryException &e) {
// Wrap query exceptions in a client error to indicate to the client that
// it should fix the query and try again.
@ -84,14 +84,23 @@ class BoltSession final
}
}
std::map<std::string, communication::bolt::Value> PullAll(
TEncoder *encoder) override {
std::map<std::string, communication::bolt::Value> Pull(
TEncoder *encoder, std::optional<int>, std::optional<int>) override {
for (const auto &record : records_) {
encoder->MessageRecord(record);
}
return metadata_;
}
std::map<std::string, communication::bolt::Value> Discard(
std::optional<int>, std::optional<int>) override {
return {};
}
void BeginTransaction() override {}
void CommitTransaction() override {}
void RollbackTransaction() override {}
void Abort() override {
// Called only for cleanup.
records_.clear();

View File

@ -16,7 +16,7 @@ int main(int argc, char *argv[]) {
query::Interpreter interpreter{&interpreter_context};
ResultStreamFaker stream(&db);
auto [header, _] = interpreter.Prepare(argv[1], {});
auto [header, _, qid] = interpreter.Prepare(argv[1], {});
stream.Header(header);
auto summary = interpreter.PullAll(&stream);
stream.Summary(summary);

View File

@ -1,13 +1,14 @@
#include "bolt_common.hpp"
#include "communication/buffer.hpp"
#include "communication/bolt/v1/decoder/chunked_decoder_buffer.hpp"
#include "communication/buffer.hpp"
constexpr const int SIZE = 131072;
uint8_t data[SIZE];
using BufferT = communication::Buffer;
using StreamBufferT = io::network::StreamBuffer;
using DecoderBufferT = communication::bolt::ChunkedDecoderBuffer<BufferT::ReadEnd>;
using DecoderBufferT =
communication::bolt::ChunkedDecoderBuffer<BufferT::ReadEnd>;
using ChunkStateT = communication::bolt::ChunkState;
TEST(BoltBuffer, CorrectChunk) {

View File

@ -3,6 +3,7 @@
#include "bolt_common.hpp"
#include "communication/bolt/v1/session.hpp"
#include "communication/exceptions.hpp"
using communication::bolt::ClientError;
using communication::bolt::Session;
@ -12,6 +13,7 @@ using communication::bolt::Value;
static const char *kInvalidQuery = "invalid query";
static const char *kQueryReturn42 = "RETURN 42";
static const char *kQueryReturnMultiple = "UNWIND [1,2,3] as n RETURN n";
static const char *kQueryEmpty = "no results";
class TestSessionData {};
@ -25,29 +27,58 @@ class TestSession : public Session<TestInputStream, TestOutputStream> {
: Session<TestInputStream, TestOutputStream>(input_stream,
output_stream) {}
std::vector<std::string> Interpret(
std::pair<std::vector<std::string>, std::optional<int>> Interpret(
const std::string &query,
const std::map<std::string, Value> &params) override {
if (query == kQueryReturn42 || query == kQueryEmpty) {
if (query == kQueryReturn42 || query == kQueryEmpty ||
query == kQueryReturnMultiple) {
query_ = query;
return {"result_name"};
return {{"result_name"}, {}};
} else {
query_ = "";
throw ClientError("client sent invalid query");
}
}
std::map<std::string, Value> PullAll(TEncoder *encoder) override {
std::map<std::string, Value> Pull(TEncoder *encoder, std::optional<int> n,
std::optional<int> qid) override {
if (query_ == kQueryReturn42) {
encoder->MessageRecord(std::vector<Value>{Value(42)});
return {};
} else if (query_ == kQueryEmpty) {
return {};
} else if (query_ == kQueryReturnMultiple) {
static const std::array elements{1, 2, 3};
static size_t global_counter = 0;
int local_counter = 0;
for (; global_counter < elements.size() && (!n || local_counter < *n);
++global_counter) {
encoder->MessageRecord(
std::vector<Value>{Value(elements[global_counter])});
++local_counter;
}
if (global_counter == elements.size()) {
global_counter = 0;
return {std::pair("has_more", false)};
}
return {std::pair("has_more", true)};
} else {
throw ClientError("client sent invalid query");
}
}
std::map<std::string, Value> Discard(std::optional<int>,
std::optional<int>) override {
return {};
}
void BeginTransaction() override {}
void CommitTransaction() override {}
void RollbackTransaction() override {}
void Abort() override {}
bool Authenticate(const std::string &username,
@ -73,25 +104,65 @@ class TestSession : public Session<TestInputStream, TestOutputStream> {
std::vector<uint8_t> &output = output_stream.output;
// Sample testdata that has correct inputs and outputs.
const uint8_t handshake_req[] = {0x60, 0x60, 0xb0, 0x17, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00};
const uint8_t handshake_resp[] = {0x00, 0x00, 0x00, 0x01};
const uint8_t init_req[] = {
constexpr uint8_t handshake_req[] = {0x60, 0x60, 0xb0, 0x17, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00};
constexpr uint8_t handshake_resp[] = {0x00, 0x00, 0x00, 0x01};
constexpr uint8_t init_req[] = {
0xb2, 0x01, 0xd0, 0x15, 0x6c, 0x69, 0x62, 0x6e, 0x65, 0x6f, 0x34,
0x6a, 0x2d, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x2f, 0x31, 0x2e,
0x32, 0x2e, 0x31, 0xa3, 0x86, 0x73, 0x63, 0x68, 0x65, 0x6d, 0x65,
0x85, 0x62, 0x61, 0x73, 0x69, 0x63, 0x89, 0x70, 0x72, 0x69, 0x6e,
0x63, 0x69, 0x70, 0x61, 0x6c, 0x80, 0x8b, 0x63, 0x72, 0x65, 0x64,
0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x73, 0x80};
const uint8_t init_resp[] = {0x00, 0x03, 0xb1, 0x70, 0xa0, 0x00, 0x00};
const uint8_t run_req_header[] = {0xb2, 0x10, 0xd1};
const uint8_t pullall_req[] = {0xb0, 0x3f};
const uint8_t discardall_req[] = {0xb0, 0x2f};
const uint8_t reset_req[] = {0xb0, 0x0f};
const uint8_t ackfailure_req[] = {0xb0, 0x0e};
const uint8_t success_resp[] = {0x00, 0x03, 0xb1, 0x70, 0xa0, 0x00, 0x00};
const uint8_t ignored_resp[] = {0x00, 0x02, 0xb0, 0x7e, 0x00, 0x00};
constexpr uint8_t init_resp[] = {0x00, 0x18, 0xb1, 0x70, 0xa1, 0x8d, 0x63,
0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69,
0x6f, 0x6e, 0x5f, 0x69, 0x64, 0x86, 0x62,
0x6f, 0x6c, 0x74, 0x2d, 0x31, 0x00, 0x00};
constexpr uint8_t run_req_header[] = {0xb2, 0x10, 0xd1};
constexpr uint8_t pullall_req[] = {0xb0, 0x3f};
constexpr uint8_t discardall_req[] = {0xb0, 0x2f};
constexpr uint8_t reset_req[] = {0xb0, 0x0f};
constexpr uint8_t ackfailure_req[] = {0xb0, 0x0e};
constexpr uint8_t success_resp[] = {0x00, 0x03, 0xb1, 0x70, 0xa0, 0x00, 0x00};
constexpr uint8_t ignored_resp[] = {0x00, 0x02, 0xb0, 0x7e, 0x00, 0x00};
// Raw Bolt v4.0 byte sequences used by the tests below.
namespace v4 {
// Handshake: preamble 0x6060B017 followed by four proposed versions;
// version 4.0 (0x00 0x00 0x00 0x04) is proposed first.
constexpr uint8_t handshake_req[] = {0x60, 0x60, 0xb0, 0x17, 0x00, 0x00, 0x00,
0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00};
// Server answers with the selected version (4.0).
constexpr uint8_t handshake_resp[] = {0x00, 0x00, 0x00, 0x04};
// HELLO (signature 0x01) with a single map argument: user_agent, scheme,
// principal, credentials and routing (as sent by neo4j-python/4.1.1).
constexpr uint8_t init_req[] = {
0xb1, 0x01, 0xa5, 0x8a, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x61, 0x67, 0x65,
0x6e, 0x74, 0xd0, 0x2f, 0x6e, 0x65, 0x6f, 0x34, 0x6a, 0x2d, 0x70, 0x79,
0x74, 0x68, 0x6f, 0x6e, 0x2f, 0x34, 0x2e, 0x31, 0x2e, 0x31, 0x20, 0x50,
0x79, 0x74, 0x68, 0x6f, 0x6e, 0x2f, 0x33, 0x2e, 0x37, 0x2e, 0x33, 0x2d,
0x66, 0x69, 0x6e, 0x61, 0x6c, 0x2d, 0x30, 0x20, 0x28, 0x6c, 0x69, 0x6e,
0x75, 0x78, 0x29, 0x86, 0x73, 0x63, 0x68, 0x65, 0x6d, 0x65, 0x85, 0x62,
0x61, 0x73, 0x69, 0x63, 0x89, 0x70, 0x72, 0x69, 0x6e, 0x63, 0x69, 0x70,
0x61, 0x6c, 0x80, 0x8b, 0x63, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69,
0x61, 0x6c, 0x73, 0x80, 0x87, 0x72, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67,
0xa1, 0x87, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x8e, 0x6c, 0x6f,
0x63, 0x61, 0x6c, 0x68, 0x6f, 0x73, 0x74, 0x3a, 0x37, 0x36, 0x38, 0x37};
// SUCCESS response carrying {"connection_id": "bolt-1"}.
constexpr uint8_t init_resp[] = {0x00, 0x18, 0xb1, 0x70, 0xa1, 0x8d, 0x63,
0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69,
0x6f, 0x6e, 0x5f, 0x69, 0x64, 0x86, 0x62,
0x6f, 0x6c, 0x74, 0x2d, 0x31, 0x00, 0x00};
// RUN (0x10) now takes 3 fields (struct marker 0xb3): query, params, extra.
constexpr uint8_t run_req_header[] = {0xb3, 0x10, 0xd1};
// PULL (0x3f) with an empty extra map -> pull everything.
constexpr uint8_t pullall_req[] = {0xb1, 0x3f, 0xa0};
// PULL with extra map {"n": 1} -> pull a single record.
constexpr uint8_t pull_one_req[] = {0xb1, 0x3f, 0xa1, 0x81, 0x6e, 0x01};
// RESET (0x0f).
constexpr uint8_t reset_req[] = {0xb0, 0x0f};
// GOODBYE (0x02) -- new in v4; asks the server to close the session.
constexpr uint8_t goodbye[] = {0xb0, 0x02};
} // namespace v4
// Raw Bolt v4.1 byte sequences.
namespace v4_1 {
// Handshake proposing version 4.1 (minor=0x01, major=0x04) first.
constexpr uint8_t handshake_req[] = {0x60, 0x60, 0xb0, 0x17, 0x00, 0x00, 0x01,
0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00};
// Server answers with the selected version (4.1).
constexpr uint8_t handshake_resp[] = {0x00, 0x00, 0x01, 0x04};
// NOOP: an empty chunk (zero-length header) used as a keep-alive in v4.1.
constexpr uint8_t noop[] = {0x00, 0x00};
} // namespace v4_1
// Write bolt chunk header (length)
void WriteChunkHeader(TestInputStream &input_stream, uint16_t len) {
@ -135,12 +206,14 @@ void CheckIgnoreMessage(std::vector<uint8_t> &output) {
// Execute and check a correct handshake
void ExecuteHandshake(TestInputStream &input_stream, TestSession &session,
std::vector<uint8_t> &output) {
input_stream.Write(handshake_req, 20);
std::vector<uint8_t> &output,
const uint8_t *request = handshake_req,
const uint8_t *expected_resp = handshake_resp) {
input_stream.Write(request, 20);
session.Execute();
ASSERT_EQ(session.state_, State::Init);
PrintOutput(output);
CheckOutput(output, handshake_resp, 4);
CheckOutput(output, expected_resp, 4);
}
// Write bolt chunk and execute command
@ -154,21 +227,28 @@ void ExecuteCommand(TestInputStream &input_stream, TestSession &session,
// Execute and check a correct init
void ExecuteInit(TestInputStream &input_stream, TestSession &session,
std::vector<uint8_t> &output) {
ExecuteCommand(input_stream, session, init_req, sizeof(init_req));
std::vector<uint8_t> &output, const bool is_v4 = false) {
const auto *request = is_v4 ? v4::init_req : init_req;
const auto request_size = is_v4 ? sizeof(v4::init_req) : sizeof(init_req);
ExecuteCommand(input_stream, session, request, request_size);
ASSERT_EQ(session.state_, State::Idle);
PrintOutput(output);
CheckOutput(output, init_resp, 7);
const auto *response = is_v4 ? v4::init_resp : init_resp;
CheckOutput(output, response, 28);
}
// Write bolt encoded run request
void WriteRunRequest(TestInputStream &input_stream, const char *str) {
void WriteRunRequest(TestInputStream &input_stream, const char *str,
const bool is_v4 = false) {
// write chunk header
auto len = strlen(str);
WriteChunkHeader(input_stream, 3 + 2 + len + 1);
WriteChunkHeader(input_stream, (3 + is_v4) + 2 + len + 1);
const auto *run_header = is_v4 ? v4::run_req_header : run_req_header;
const auto run_header_size =
is_v4 ? sizeof(v4::run_req_header) : sizeof(run_req_header);
// write string header
input_stream.Write(run_req_header, 3);
input_stream.Write(run_header, run_header_size);
// write string length
WriteChunkHeader(input_stream, len);
@ -179,6 +259,11 @@ void WriteRunRequest(TestInputStream &input_stream, const char *str) {
// write empty map for parameters
input_stream.Write("\xA0", 1); // TinyMap0
if (is_v4) {
// write empty map for extra field
input_stream.Write("\xA0", 1); // TinyMap
}
// write chunk tail
WriteChunkTail(input_stream);
}
@ -227,6 +312,46 @@ TEST(BoltSession, HandshakeOK) {
ExecuteHandshake(input_stream, session, output);
}
// The handshake carries up to four proposed versions in priority order;
// the server must select the first one it supports, or fail the session
// when none of the proposals are supported.
TEST(BoltSession, HandshakeMultiVersionRequest) {
// Should pick the first version, 4.0, even though a higher version is present
// but with a lower priority
{
INIT_VARS;
const uint8_t priority_request[] = {
0x60, 0x60, 0xb0, 0x17, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00,
0x01, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00};
const uint8_t priority_response[] = {0x00, 0x00, 0x00, 0x04};
ExecuteHandshake(input_stream, session, output, priority_request,
priority_response);
ASSERT_EQ(session.version_.minor, 0);
ASSERT_EQ(session.version_.major, 4);
}

// Should pick the second version, 4.1, because first, 3.0, is not supported
{
INIT_VARS;
const uint8_t unsupported_first_request[] = {
0x60, 0x60, 0xb0, 0x17, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00,
0x01, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00};
const uint8_t unsupported_first_response[] = {0x00, 0x00, 0x01, 0x04};
ExecuteHandshake(input_stream, session, output, unsupported_first_request,
unsupported_first_response);
ASSERT_EQ(session.version_.minor, 1);
ASSERT_EQ(session.version_.major, 4);
}

// No supported version present in the request
{
INIT_VARS;
const uint8_t no_supported_versions_request[] = {
0x60, 0x60, 0xb0, 0x17, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00,
0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00};
ASSERT_THROW(ExecuteHandshake(input_stream, session, output,
no_supported_versions_request),
SessionException);
}
}
TEST(BoltSession, InitWrongSignature) {
INIT_VARS;
ExecuteHandshake(input_stream, session, output);
@ -280,9 +405,17 @@ TEST(BoltSession, InitWriteFail) {
}
TEST(BoltSession, InitOK) {
INIT_VARS;
ExecuteHandshake(input_stream, session, output);
ExecuteInit(input_stream, session, output);
{
INIT_VARS;
ExecuteHandshake(input_stream, session, output);
ExecuteInit(input_stream, session, output);
}
{
INIT_VARS;
ExecuteHandshake(input_stream, session, output, v4::handshake_req,
v4::handshake_resp);
ExecuteInit(input_stream, session, output, true);
}
}
TEST(BoltSession, ExecuteRunWrongMarker) {
@ -344,15 +477,32 @@ TEST(BoltSession, ExecuteRunBasicException) {
}
TEST(BoltSession, ExecuteRunWithoutPullAll) {
INIT_VARS;
// v1
{
INIT_VARS;
ExecuteHandshake(input_stream, session, output);
ExecuteInit(input_stream, session, output);
ExecuteHandshake(input_stream, session, output);
ExecuteInit(input_stream, session, output);
WriteRunRequest(input_stream, kQueryReturn42);
session.Execute();
WriteRunRequest(input_stream, kQueryReturn42);
session.Execute();
ASSERT_EQ(session.state_, State::Result);
ASSERT_EQ(session.state_, State::Result);
}
// v4+
{
INIT_VARS;
ExecuteHandshake(input_stream, session, output, v4::handshake_req,
v4::handshake_resp);
ExecuteInit(input_stream, session, output, true);
WriteRunRequest(input_stream, kQueryReturn42, true);
session.Execute();
ASSERT_EQ(session.state_, State::Result);
}
}
TEST(BoltSession, ExecutePullAllDiscardAllResetWrongMarker) {
@ -401,35 +551,37 @@ TEST(BoltSession, ExecutePullAllBufferEmpty) {
TEST(BoltSession, ExecutePullAllDiscardAllReset) {
// This test first tests PULL_ALL then DISCARD_ALL and then RESET
// It tests a good message
const uint8_t *dataset[3] = {pullall_req, discardall_req, reset_req};
{
const uint8_t *dataset[3] = {pullall_req, discardall_req, reset_req};
for (int i = 0; i < 3; ++i) {
// first test with socket write success, then with socket write fail
for (int j = 0; j < 2; ++j) {
INIT_VARS;
for (int i = 0; i < 3; ++i) {
// first test with socket write success, then with socket write fail
for (int j = 0; j < 2; ++j) {
INIT_VARS;
ExecuteHandshake(input_stream, session, output);
ExecuteInit(input_stream, session, output);
WriteRunRequest(input_stream, kQueryReturn42);
session.Execute();
ExecuteHandshake(input_stream, session, output);
ExecuteInit(input_stream, session, output);
WriteRunRequest(input_stream, kQueryReturn42);
session.Execute();
if (j == 1) output.clear();
if (j == 1) output.clear();
output_stream.SetWriteSuccess(j == 0);
if (j == 0) {
ExecuteCommand(input_stream, session, dataset[i], 2);
} else {
ASSERT_THROW(ExecuteCommand(input_stream, session, dataset[i], 2),
SessionException);
}
output_stream.SetWriteSuccess(j == 0);
if (j == 0) {
ExecuteCommand(input_stream, session, dataset[i], 2);
} else {
ASSERT_THROW(ExecuteCommand(input_stream, session, dataset[i], 2),
SessionException);
}
if (j == 0) {
ASSERT_EQ(session.state_, State::Idle);
ASSERT_FALSE(session.encoder_buffer_.HasData());
PrintOutput(output);
} else {
ASSERT_EQ(session.state_, State::Close);
ASSERT_EQ(output.size(), 0);
if (j == 0) {
ASSERT_EQ(session.state_, State::Idle);
ASSERT_FALSE(session.encoder_buffer_.HasData());
PrintOutput(output);
} else {
ASSERT_EQ(session.state_, State::Close);
ASSERT_EQ(output.size(), 0);
}
}
}
}
@ -548,39 +700,69 @@ TEST(BoltSession, ErrorWrongMarker) {
}
TEST(BoltSession, ErrorOK) {
// test ACK_FAILURE and RESET
const uint8_t *dataset[] = {ackfailure_req, reset_req};
// v1
{
// test ACK_FAILURE and RESET
const uint8_t *dataset[] = {ackfailure_req, reset_req};
for (int i = 0; i < 2; ++i) {
// first test with socket write success, then with socket write fail
for (int j = 0; j < 2; ++j) {
for (int i = 0; i < 2; ++i) {
// first test with socket write success, then with socket write fail
for (int j = 0; j < 2; ++j) {
INIT_VARS;
ExecuteHandshake(input_stream, session, output);
ExecuteInit(input_stream, session, output);
WriteRunRequest(input_stream, kInvalidQuery);
session.Execute();
output.clear();
output_stream.SetWriteSuccess(j == 0);
if (j == 0) {
ExecuteCommand(input_stream, session, dataset[i], 2);
} else {
ASSERT_THROW(ExecuteCommand(input_stream, session, dataset[i], 2),
SessionException);
}
// assert that all data from the init message was cleaned up
ASSERT_EQ(session.decoder_buffer_.Size(), 0);
if (j == 0) {
ASSERT_EQ(session.state_, State::Idle);
CheckOutput(output, success_resp, sizeof(success_resp));
} else {
ASSERT_EQ(session.state_, State::Close);
ASSERT_EQ(output.size(), 0);
}
}
}
}
// v4+
{
const uint8_t *dataset[] = {ackfailure_req, v4::reset_req};
for (int i = 0; i < 2; ++i) {
INIT_VARS;
ExecuteHandshake(input_stream, session, output);
ExecuteInit(input_stream, session, output);
ExecuteHandshake(input_stream, session, output, v4::handshake_req,
v4::handshake_resp);
ExecuteInit(input_stream, session, output, true);
WriteRunRequest(input_stream, kInvalidQuery);
WriteRunRequest(input_stream, kInvalidQuery, true);
session.Execute();
output.clear();
output_stream.SetWriteSuccess(j == 0);
if (j == 0) {
ExecuteCommand(input_stream, session, dataset[i], 2);
ExecuteCommand(input_stream, session, dataset[i], 2);
// ACK_FAILURE does not exist in v4+
if (i == 0) {
ASSERT_EQ(session.state_, State::Error);
} else {
ASSERT_THROW(ExecuteCommand(input_stream, session, dataset[i], 2),
SessionException);
}
// assert that all data from the init message was cleaned up
ASSERT_EQ(session.decoder_buffer_.Size(), 0);
if (j == 0) {
ASSERT_EQ(session.state_, State::Idle);
CheckOutput(output, success_resp, sizeof(success_resp));
} else {
ASSERT_EQ(session.state_, State::Close);
ASSERT_EQ(output.size(), 0);
}
}
}
@ -633,6 +815,51 @@ TEST(BoltSession, MultipleChunksInOneExecute) {
ASSERT_EQ(num, 3);
}
// Bolt v4 PULL with {"n": 1} must return a single record and keep the
// session in the Result state; a following PULL with no limit drains the
// remaining records and returns the session to Idle.
TEST(BoltSession, PartialPull) {
  INIT_VARS;

  ExecuteHandshake(input_stream, session, output, v4::handshake_req,
                   v4::handshake_resp);
  ExecuteInit(input_stream, session, output, true);

  WriteRunRequest(input_stream, kQueryReturnMultiple, true);
  ExecuteCommand(input_stream, session, v4::pull_one_req,
                 sizeof(v4::pull_one_req));

  // Not all results were pulled
  ASSERT_EQ(session.state_, State::Result);
  PrintOutput(output);

  // Count chunked messages in the output; each message occupies
  // <2-byte length><payload><2-byte end marker>, hence len + 4.
  // NOTE: the original declared `int len, num = 0;`, leaving `len`
  // uninitialized until the loop body; declare one per line, initialized.
  int len = 0;
  int num = 0;
  while (output.size() > 0) {
    len = (output[0] << 8) + output[1];
    output.erase(output.begin(), output.begin() + len + 4);
    ++num;
  }
  // the first is a success with the query headers
  // the second is a record message
  // and the last is a success message with query run metadata
  ASSERT_EQ(num, 3);

  ExecuteCommand(input_stream, session, v4::pullall_req,
                 sizeof(v4::pullall_req));
  ASSERT_EQ(session.state_, State::Idle);
  PrintOutput(output);

  num = 0;
  while (output.size() > 0) {
    len = (output[0] << 8) + output[1];
    output.erase(output.begin(), output.begin() + len + 4);
    ++num;
  }
  // First two are the record messages
  // and the last is a success message with query run metadata
  ASSERT_EQ(num, 3);
}
TEST(BoltSession, PartialChunk) {
INIT_VARS;
ExecuteHandshake(input_stream, session, output);
@ -656,6 +883,86 @@ TEST(BoltSession, PartialChunk) {
PrintOutput(output);
}
// GOODBYE exists only in Bolt v4+: a v4 session closes cleanly
// (SessionClosedException), while a v1 session treats it as a protocol
// error (SessionException).
TEST(BoltSession, Goodbye) {
// v4 supports goodbye message
{
INIT_VARS;
ExecuteHandshake(input_stream, session, output, v4::handshake_req,
v4::handshake_resp);
ExecuteInit(input_stream, session, output, true);
ASSERT_THROW(
ExecuteCommand(input_stream, session, v4::goodbye, sizeof(v4::goodbye)),
communication::SessionClosedException);
}

// v1 does not support goodbye message
{
INIT_VARS;
ExecuteHandshake(input_stream, session, output);
ExecuteInit(input_stream, session, output);
ASSERT_THROW(
ExecuteCommand(input_stream, session, v4::goodbye, sizeof(v4::goodbye)),
SessionException);
}
}
// NOOP (an empty chunk) is a v4.1 keep-alive: it must be silently ignored
// in every session state on v4.1, and rejected as a protocol error on v1.
TEST(BoltSession, Noop) {
  // v4.1 supports NOOP chunk
  {
    INIT_VARS;

    ExecuteHandshake(input_stream, session, output, v4_1::handshake_req,
                     v4_1::handshake_resp);
    ExecuteCommand(input_stream, session, v4_1::noop, sizeof(v4_1::noop));
    ExecuteInit(input_stream, session, output, true);
    ExecuteCommand(input_stream, session, v4_1::noop, sizeof(v4_1::noop));
    WriteRunRequest(input_stream, kQueryReturn42, true);
    ExecuteCommand(input_stream, session, v4_1::noop, sizeof(v4_1::noop));
    ExecuteCommand(input_stream, session, v4::pullall_req,
                   sizeof(v4::pullall_req));
    ExecuteCommand(input_stream, session, v4_1::noop, sizeof(v4_1::noop));
  }

  // v1 does not support NOOP chunk
  {
    INIT_VARS;

    ExecuteHandshake(input_stream, session, output, handshake_req,
                     handshake_resp);
    ASSERT_THROW(
        ExecuteCommand(input_stream, session, v4_1::noop, sizeof(v4_1::noop)),
        SessionException);
    CheckFailureMessage(output);

    session.state_ = State::Init;
    ExecuteInit(input_stream, session, output);
    ASSERT_THROW(
        ExecuteCommand(input_stream, session, v4_1::noop, sizeof(v4_1::noop)),
        SessionException);
    CheckFailureMessage(output);

    session.state_ = State::Idle;
    WriteRunRequest(input_stream, kQueryEmpty);
    session.Execute();
    CheckSuccessMessage(output);
    ASSERT_THROW(
        ExecuteCommand(input_stream, session, v4_1::noop, sizeof(v4_1::noop)),
        SessionException);
    CheckFailureMessage(output);

    session.state_ = State::Result;
    // BUGFIX: size must match the v1 `pullall_req` array actually sent;
    // the original passed sizeof(v4::pullall_req) (3 bytes) for the 2-byte
    // v1 array, reading one byte past its end.
    ExecuteCommand(input_stream, session, pullall_req, sizeof(pullall_req));
    CheckSuccessMessage(output);
    ASSERT_THROW(
        ExecuteCommand(input_stream, session, v4_1::noop, sizeof(v4_1::noop)),
        SessionException);
  }
}
int main(int argc, char **argv) {
google::InitGoogleLogging(argv[0]);
::testing::InitGoogleTest(&argc, argv);

View File

@ -22,13 +22,13 @@ const uint8_t int_encoded[][10] = {
// clang-format on
const uint32_t int_encoded_len[] = {1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3,
3, 3, 5, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9, 9};
3, 3, 5, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9, 9};
const double double_decoded[] = {5.834, 108.199, 43677.9882, 254524.5851};
const uint8_t double_encoded[][10] = {"\xC1\x40\x17\x56\x04\x18\x93\x74\xBC",
"\xC1\x40\x5B\x0C\xBC\x6A\x7E\xF9\xDB",
"\xC1\x40\xE5\x53\xBF\x9F\x55\x9B\x3D",
"\xC1\x41\x0F\x11\xE4\xAE\x48\xE8\xA7"};
"\xC1\x40\x5B\x0C\xBC\x6A\x7E\xF9\xDB",
"\xC1\x40\xE5\x53\xBF\x9F\x55\x9B\x3D",
"\xC1\x41\x0F\x11\xE4\xAE\x48\xE8\xA7"};
const uint8_t vertexedge_encoded[] =
"\xB1\x71\x93\xB3\x4E\x00\x92\x86\x6C\x61\x62\x65\x6C\x31\x86\x6C\x61\x62"
@ -37,8 +37,7 @@ const uint8_t vertexedge_encoded[] =
"\x79\x70\x65\xA2\x85\x70\x72\x6F\x70\x33\x2A\x85\x70\x72\x6F\x70\x34\xC9"
"\x04\xD2";
const uint64_t sizes[] = {0, 1, 5, 15, 16, 120,
255, 256, 12345, 65535, 65536};
const uint64_t sizes[] = {0, 1, 5, 15, 16, 120, 255, 256, 12345, 65535, 65536};
const uint64_t sizes_num = 11;
constexpr const int STRING = 0, LIST = 1, MAP = 2;

View File

@ -7,8 +7,10 @@
#include "gtest/gtest.h"
#include "query/exceptions.hpp"
#include "query/interpreter.hpp"
#include "query/stream.hpp"
#include "query/typed_value.hpp"
#include "query_common.hpp"
#include "storage/v2/property_value.hpp"
namespace {
@ -31,6 +33,22 @@ class InterpreterTest : public ::testing::Test {
query::InterpreterContext interpreter_context_{&db_};
query::Interpreter interpreter_{&interpreter_context_};
// Prepares `query` on the member interpreter, writes the resulting header
// to a fresh ResultStreamFaker, and returns {stream, query id} so the
// caller can drive Pull() incrementally.
auto Prepare(
const std::string &query,
const std::map<std::string, storage::PropertyValue> &params = {}) {
ResultStreamFaker stream(&db_);
const auto [header, _, qid] = interpreter_.Prepare(query, params);
stream.Header(header);
return std::pair{std::move(stream), qid};
}
// Pulls up to `n` results (all when unset) for query `qid` (the last
// prepared query when unset) and records the returned summary on `stream`.
void Pull(ResultStreamFaker *stream, std::optional<int> n = {},
std::optional<int> qid = {}) {
const auto summary = interpreter_.Pull(stream, n, qid);
stream->Summary(summary);
}
/**
* Execute the given query and commit the transaction.
*
@ -39,17 +57,41 @@ class InterpreterTest : public ::testing::Test {
auto Interpret(
const std::string &query,
const std::map<std::string, storage::PropertyValue> &params = {}) {
ResultStreamFaker stream(&db_);
auto prepare_result = Prepare(query, params);
auto [header, _] = interpreter_.Prepare(query, params);
stream.Header(header);
auto summary = interpreter_.PullAll(&stream);
auto &stream = prepare_result.first;
auto summary = interpreter_.Pull(&stream, {}, prepare_result.second);
stream.Summary(summary);
return stream;
return std::move(stream);
}
};
// A single prepared query can be consumed across several Pull calls with
// different batch sizes; "has_more" must stay true until the stream is
// fully drained.
TEST_F(InterpreterTest, MultiplePulls) {
{
auto [stream, qid] = Prepare("UNWIND [1,2,3,4,5] as n RETURN n");
ASSERT_EQ(stream.GetHeader().size(), 1U);
EXPECT_EQ(stream.GetHeader()[0], "n");

// Pull a single row; more results must be pending.
Pull(&stream, 1);
ASSERT_EQ(stream.GetSummary().count("has_more"), 1);
ASSERT_TRUE(stream.GetSummary().at("has_more").ValueBool());
ASSERT_EQ(stream.GetResults()[0].size(), 1U);
ASSERT_EQ(stream.GetResults()[0][0].ValueInt(), 1);

// Pull two more rows (3 consumed in total).
Pull(&stream, 2);
ASSERT_EQ(stream.GetSummary().count("has_more"), 1);
ASSERT_TRUE(stream.GetSummary().at("has_more").ValueBool());
ASSERT_EQ(stream.GetResults().size(), 3U);
ASSERT_EQ(stream.GetResults()[1][0].ValueInt(), 2);
ASSERT_EQ(stream.GetResults()[2][0].ValueInt(), 3);

// Pull with no limit drains the remaining rows.
Pull(&stream);
ASSERT_EQ(stream.GetSummary().count("has_more"), 1);
ASSERT_FALSE(stream.GetSummary().at("has_more").ValueBool());
ASSERT_EQ(stream.GetResults().size(), 5U);
ASSERT_EQ(stream.GetResults()[3][0].ValueInt(), 4);
ASSERT_EQ(stream.GetResults()[4][0].ValueInt(), 5);
}
}
// Run query with different ast twice to see if query executes correctly when
// ast is read from cache.
TEST_F(InterpreterTest, AstCache) {
@ -481,6 +523,40 @@ TEST_F(InterpreterTest, ExplainQuery) {
EXPECT_EQ(interpreter_context_.ast_cache.size(), 2U);
}
// EXPLAIN output (one plan-operator row per result) must also be
// consumable one row at a time via Pull n.
TEST_F(InterpreterTest, ExplainQueryMultiplePulls) {
EXPECT_EQ(interpreter_context_.plan_cache.size(), 0U);
EXPECT_EQ(interpreter_context_.ast_cache.size(), 0U);
auto [stream, qid] = Prepare("EXPLAIN MATCH (n) RETURN *;");
ASSERT_EQ(stream.GetHeader().size(), 1U);
EXPECT_EQ(stream.GetHeader().front(), "QUERY PLAN");
std::vector<std::string> expected_rows{" * Produce {n}", " * ScanAll (n)",
" * Once"};

// First row of the plan.
Pull(&stream, 1);
ASSERT_EQ(stream.GetResults().size(), 1);
auto expected_it = expected_rows.begin();
ASSERT_EQ(stream.GetResults()[0].size(), 1U);
EXPECT_EQ(stream.GetResults()[0].front().ValueString(), *expected_it);
++expected_it;

// Second row.
Pull(&stream, 1);
ASSERT_EQ(stream.GetResults().size(), 2);
ASSERT_EQ(stream.GetResults()[1].size(), 1U);
EXPECT_EQ(stream.GetResults()[1].front().ValueString(), *expected_it);
++expected_it;

// Remaining rows.
Pull(&stream);
ASSERT_EQ(stream.GetResults().size(), 3);
ASSERT_EQ(stream.GetResults()[2].size(), 1U);
EXPECT_EQ(stream.GetResults()[2].front().ValueString(), *expected_it);

// We should have a plan cache for MATCH ...
EXPECT_EQ(interpreter_context_.plan_cache.size(), 1U);
// We should have AST cache for EXPLAIN ... and for inner MATCH ...
EXPECT_EQ(interpreter_context_.ast_cache.size(), 2U);

Interpret("MATCH (n) RETURN *;");
EXPECT_EQ(interpreter_context_.plan_cache.size(), 1U);
EXPECT_EQ(interpreter_context_.ast_cache.size(), 2U);
}
TEST_F(InterpreterTest, ExplainQueryInMulticommandTransaction) {
EXPECT_EQ(interpreter_context_.plan_cache.size(), 0U);
EXPECT_EQ(interpreter_context_.ast_cache.size(), 0U);
@ -557,6 +633,43 @@ TEST_F(InterpreterTest, ProfileQuery) {
EXPECT_EQ(interpreter_context_.ast_cache.size(), 2U);
}
// PROFILE output must likewise support incremental Pull n consumption,
// one operator row at a time.
TEST_F(InterpreterTest, ProfileQueryMultiplePulls) {
EXPECT_EQ(interpreter_context_.plan_cache.size(), 0U);
EXPECT_EQ(interpreter_context_.ast_cache.size(), 0U);
auto [stream, qid] = Prepare("PROFILE MATCH (n) RETURN *;");
std::vector<std::string> expected_header{"OPERATOR", "ACTUAL HITS",
"RELATIVE TIME", "ABSOLUTE TIME"};
EXPECT_EQ(stream.GetHeader(), expected_header);
std::vector<std::string> expected_rows{"* Produce", "* ScanAll", "* Once"};
auto expected_it = expected_rows.begin();

// First operator row (only column 0, the operator name, is asserted).
Pull(&stream, 1);
ASSERT_EQ(stream.GetResults().size(), 1U);
ASSERT_EQ(stream.GetResults()[0].size(), 4U);
ASSERT_EQ(stream.GetResults()[0][0].ValueString(), *expected_it);
++expected_it;

// Second operator row.
Pull(&stream, 1);
ASSERT_EQ(stream.GetResults().size(), 2U);
ASSERT_EQ(stream.GetResults()[1].size(), 4U);
ASSERT_EQ(stream.GetResults()[1][0].ValueString(), *expected_it);
++expected_it;

// Remaining rows.
Pull(&stream);
ASSERT_EQ(stream.GetResults().size(), 3U);
ASSERT_EQ(stream.GetResults()[2].size(), 4U);
ASSERT_EQ(stream.GetResults()[2][0].ValueString(), *expected_it);

// We should have a plan cache for MATCH ...
EXPECT_EQ(interpreter_context_.plan_cache.size(), 1U);
// We should have AST cache for PROFILE ... and for inner MATCH ...
EXPECT_EQ(interpreter_context_.ast_cache.size(), 2U);

Interpret("MATCH (n) RETURN *;");
EXPECT_EQ(interpreter_context_.plan_cache.size(), 1U);
EXPECT_EQ(interpreter_context_.ast_cache.size(), 2U);
}
TEST_F(InterpreterTest, ProfileQueryInMulticommandTransaction) {
Interpret("BEGIN");
ASSERT_THROW(Interpret("PROFILE MATCH (n) RETURN *;"),
@ -615,3 +728,104 @@ TEST_F(InterpreterTest, ProfileQueryWithLiterals) {
EXPECT_EQ(interpreter_context_.plan_cache.size(), 1U);
EXPECT_EQ(interpreter_context_.ast_cache.size(), 2U);
}
// Explicit transaction API: commit/rollback outside a transaction and a
// nested begin must throw; queries inside an explicit transaction report
// has_more=false once drained, for both commit and rollback paths.
TEST_F(InterpreterTest, Transactions) {
{
// No transaction is active yet -> commit/rollback must throw.
ASSERT_THROW(interpreter_.CommitTransaction(),
query::ExplicitTransactionUsageException);
ASSERT_THROW(interpreter_.RollbackTransaction(),
query::ExplicitTransactionUsageException);
interpreter_.BeginTransaction();
// Nested explicit transactions are not allowed.
ASSERT_THROW(interpreter_.BeginTransaction(),
query::ExplicitTransactionUsageException);

auto [stream, qid] = Prepare("RETURN 2");
ASSERT_EQ(stream.GetHeader().size(), 1U);
EXPECT_EQ(stream.GetHeader()[0], "2");
Pull(&stream, 1);
ASSERT_EQ(stream.GetSummary().count("has_more"), 1);
ASSERT_FALSE(stream.GetSummary().at("has_more").ValueBool());
ASSERT_EQ(stream.GetResults()[0].size(), 1U);
ASSERT_EQ(stream.GetResults()[0][0].ValueInt(), 2);
interpreter_.CommitTransaction();
}
{
// Same flow, but ended with a rollback.
interpreter_.BeginTransaction();

auto [stream, qid] = Prepare("RETURN 2");
ASSERT_EQ(stream.GetHeader().size(), 1U);
EXPECT_EQ(stream.GetHeader()[0], "2");
Pull(&stream, 1);
ASSERT_EQ(stream.GetSummary().count("has_more"), 1);
ASSERT_FALSE(stream.GetSummary().at("has_more").ValueBool());
ASSERT_EQ(stream.GetResults()[0].size(), 1U);
ASSERT_EQ(stream.GetResults()[0][0].ValueInt(), 2);
interpreter_.RollbackTransaction();
}
}
// Verifies query IDs (qids) returned by Prepare(): pulling with an unknown
// qid is rejected, and several queries prepared inside one explicit
// transaction can be pulled partially and in interleaved order, each
// addressed by its own qid.
TEST_F(InterpreterTest, Qid) {
  {
    interpreter_.BeginTransaction();
    auto [stream, qid] = Prepare("RETURN 2");
    ASSERT_TRUE(qid);
    // A qid that was never issued (*qid + 1) must be rejected.
    ASSERT_THROW(Pull(&stream, {}, *qid + 1), query::InvalidArgumentsException);
    interpreter_.RollbackTransaction();
  }
  {
    interpreter_.BeginTransaction();
    // Three UNWIND queries producing disjoint ranges (1-3, 4-6, 7-9) so each
    // pulled row identifies which stream it came from.
    auto [stream1, qid1] = Prepare("UNWIND(range(1,3)) as n RETURN n");
    ASSERT_TRUE(qid1);
    ASSERT_EQ(stream1.GetHeader().size(), 1U);
    EXPECT_EQ(stream1.GetHeader()[0], "n");
    auto [stream2, qid2] = Prepare("UNWIND(range(4,6)) as n RETURN n");
    ASSERT_TRUE(qid2);
    ASSERT_EQ(stream2.GetHeader().size(), 1U);
    EXPECT_EQ(stream2.GetHeader()[0], "n");
    // Partial pull from the first query: one of three rows, so has_more.
    Pull(&stream1, 1, qid1);
    ASSERT_EQ(stream1.GetSummary().count("has_more"), 1);
    ASSERT_TRUE(stream1.GetSummary().at("has_more").ValueBool());
    ASSERT_EQ(stream1.GetResults().size(), 1U);
    ASSERT_EQ(stream1.GetResults()[0].size(), 1U);
    ASSERT_EQ(stream1.GetResults()[0][0].ValueInt(), 1);
    // A third query can be prepared while two others are partially pulled.
    auto [stream3, qid3] = Prepare("UNWIND(range(7,9)) as n RETURN n");
    ASSERT_TRUE(qid3);
    ASSERT_EQ(stream3.GetHeader().size(), 1U);
    EXPECT_EQ(stream3.GetHeader()[0], "n");
    // Pull everything ({} = no limit) from the second query in one go.
    Pull(&stream2, {}, qid2);
    ASSERT_EQ(stream2.GetSummary().count("has_more"), 1);
    ASSERT_FALSE(stream2.GetSummary().at("has_more").ValueBool());
    ASSERT_EQ(stream2.GetResults().size(), 3U);
    ASSERT_EQ(stream2.GetResults()[0].size(), 1U);
    ASSERT_EQ(stream2.GetResults()[0][0].ValueInt(), 4);
    ASSERT_EQ(stream2.GetResults()[1][0].ValueInt(), 5);
    ASSERT_EQ(stream2.GetResults()[2][0].ValueInt(), 6);
    // Partial pull from the third query.
    Pull(&stream3, 1, qid3);
    ASSERT_EQ(stream3.GetSummary().count("has_more"), 1);
    ASSERT_TRUE(stream3.GetSummary().at("has_more").ValueBool());
    ASSERT_EQ(stream3.GetResults().size(), 1U);
    ASSERT_EQ(stream3.GetResults()[0].size(), 1U);
    ASSERT_EQ(stream3.GetResults()[0][0].ValueInt(), 7);
    // Drain the remainder of the first query; rows 2 and 3 follow row 1.
    Pull(&stream1, {}, qid1);
    ASSERT_EQ(stream1.GetSummary().count("has_more"), 1);
    ASSERT_FALSE(stream1.GetSummary().at("has_more").ValueBool());
    ASSERT_EQ(stream1.GetResults().size(), 3U);
    ASSERT_EQ(stream1.GetResults()[1].size(), 1U);
    ASSERT_EQ(stream1.GetResults()[1][0].ValueInt(), 2);
    ASSERT_EQ(stream1.GetResults()[2][0].ValueInt(), 3);
    // Drain the third query without an explicit qid (defaults to the most
    // recently prepared query).
    Pull(&stream3);
    ASSERT_EQ(stream3.GetSummary().count("has_more"), 1);
    ASSERT_FALSE(stream3.GetSummary().at("has_more").ValueBool());
    ASSERT_EQ(stream3.GetResults().size(), 3U);
    ASSERT_EQ(stream3.GetResults()[1].size(), 1U);
    ASSERT_EQ(stream3.GetResults()[1][0].ValueInt(), 8);
    ASSERT_EQ(stream3.GetResults()[2][0].ValueInt(), 9);
    interpreter_.CommitTransaction();
  }
}

View File

@ -215,7 +215,7 @@ auto Execute(storage::Storage *db, const std::string &query) {
query::Interpreter interpreter(&context);
ResultStreamFaker stream(db);
auto [header, _] = interpreter.Prepare(query, {});
auto [header, _, qid] = interpreter.Prepare(query, {});
stream.Header(header);
auto summary = interpreter.PullAll(&stream);
stream.Summary(summary);
@ -792,7 +792,7 @@ class StatefulInterpreter {
auto Execute(const std::string &query) {
ResultStreamFaker stream(db_);
auto [header, _] = interpreter_.Prepare(query, {});
auto [header, _, qid] = interpreter_.Prepare(query, {});
stream.Header(header);
auto summary = interpreter_.PullAll(&stream);
stream.Summary(summary);
@ -872,3 +872,119 @@ TEST(DumpTest, ExecuteDumpDatabaseInMulticommandTransaction) {
// Rollback the transaction.
interpreter.Execute("ROLLBACK");
}
// NOLINTNEXTLINE(hicpp-special-member-functions)
// Dumps a database one row at a time via PullPlanDump::Pull(n=1) and checks
// that each partial pull yields exactly the next expected dump statement.
TEST(DumpTest, MultiplePartialPulls) {
  storage::Storage db;
  {
    // Create indices
    db.CreateIndex(db.NameToLabel("PERSON"), db.NameToProperty("name"));
    db.CreateIndex(db.NameToLabel("PERSON"), db.NameToProperty("surname"));

    // Create existence constraints
    {
      auto result = db.CreateExistenceConstraint(db.NameToLabel("PERSON"),
                                                 db.NameToProperty("name"));
      ASSERT_TRUE(result.HasValue());
      ASSERT_TRUE(result.GetValue());
    }
    {
      auto result = db.CreateExistenceConstraint(db.NameToLabel("PERSON"),
                                                 db.NameToProperty("surname"));
      ASSERT_TRUE(result.HasValue());
      ASSERT_TRUE(result.GetValue());
    }

    // Create unique constraints
    {
      auto result = db.CreateUniqueConstraint(db.NameToLabel("PERSON"),
                                              {db.NameToProperty("name")});
      ASSERT_TRUE(result.HasValue());
      ASSERT_EQ(result.GetValue(),
                storage::UniqueConstraints::CreationStatus::SUCCESS);
    }
    {
      auto result = db.CreateUniqueConstraint(db.NameToLabel("PERSON"),
                                              {db.NameToProperty("surname")});
      ASSERT_TRUE(result.HasValue());
      ASSERT_EQ(result.GetValue(),
                storage::UniqueConstraints::CreationStatus::SUCCESS);
    }

    // Populate the graph: five PERSON vertices and four REL edges.
    auto access = db.Access();
    auto person1 = CreateVertex(&access, {"PERSON"},
                                {{"name", storage::PropertyValue("Person1")},
                                 {"surname", storage::PropertyValue("Unique1")}},
                                false);
    auto person2 = CreateVertex(&access, {"PERSON"},
                                {{"name", storage::PropertyValue("Person2")},
                                 {"surname", storage::PropertyValue("Unique2")}},
                                false);
    auto person3 = CreateVertex(&access, {"PERSON"},
                                {{"name", storage::PropertyValue("Person3")},
                                 {"surname", storage::PropertyValue("Unique3")}},
                                false);
    auto person4 = CreateVertex(&access, {"PERSON"},
                                {{"name", storage::PropertyValue("Person4")},
                                 {"surname", storage::PropertyValue("Unique4")}},
                                false);
    auto person5 = CreateVertex(&access, {"PERSON"},
                                {{"name", storage::PropertyValue("Person5")},
                                 {"surname", storage::PropertyValue("Unique5")}},
                                false);
    CreateEdge(&access, &person1, &person2, "REL", {}, false);
    CreateEdge(&access, &person1, &person3, "REL", {}, false);
    CreateEdge(&access, &person4, &person5, "REL", {}, false);
    CreateEdge(&access, &person2, &person5, "REL", {}, false);
    ASSERT_FALSE(access.Commit().HasError());
  }

  ResultStreamFaker stream(&db);
  query::AnyStream query_stream(&stream, utils::NewDeleteResource());
  auto acc = db.Access();
  query::DbAccessor db_accessor(&acc);
  query::PullPlanDump pull_plan{&db_accessor};

  // Pulls exactly one row from the dump plan and verifies it equals the
  // expected statement; tracks how many rows were verified so far.
  auto expect_row = [&, rows_verified =
                            0U](const std::string &expected_row) mutable {
    pull_plan.Pull(&query_stream, 1);
    const auto &results = stream.GetResults();
    ASSERT_EQ(results.size(), rows_verified + 1);
    VerifyQueries({results.begin() + rows_verified, results.end()},
                  expected_row);
    ++rows_verified;
  };

  // Schema first: indices, existence constraints, unique constraints.
  expect_row("CREATE INDEX ON :`PERSON`(`name`);");
  expect_row("CREATE INDEX ON :`PERSON`(`surname`);");
  expect_row("CREATE CONSTRAINT ON (u:`PERSON`) ASSERT EXISTS (u.`name`);");
  expect_row("CREATE CONSTRAINT ON (u:`PERSON`) ASSERT EXISTS (u.`surname`);");
  expect_row("CREATE CONSTRAINT ON (u:`PERSON`) ASSERT u.`name` IS UNIQUE;");
  expect_row("CREATE CONSTRAINT ON (u:`PERSON`) ASSERT u.`surname` IS UNIQUE;");
  // Then the internal index, the vertices, the edges, and the cleanup.
  expect_row(kCreateInternalIndex);
  expect_row(
      R"r(CREATE (:__mg_vertex__:`PERSON` {__mg_id__: 0, `name`: "Person1", `surname`: "Unique1"});)r");
  expect_row(
      R"r(CREATE (:__mg_vertex__:`PERSON` {__mg_id__: 1, `name`: "Person2", `surname`: "Unique2"});)r");
  expect_row(
      R"r(CREATE (:__mg_vertex__:`PERSON` {__mg_id__: 2, `name`: "Person3", `surname`: "Unique3"});)r");
  expect_row(
      R"r(CREATE (:__mg_vertex__:`PERSON` {__mg_id__: 3, `name`: "Person4", `surname`: "Unique4"});)r");
  expect_row(
      R"r(CREATE (:__mg_vertex__:`PERSON` {__mg_id__: 4, `name`: "Person5", `surname`: "Unique5"});)r");
  expect_row(
      "MATCH (u:__mg_vertex__), (v:__mg_vertex__) WHERE u.__mg_id__ = 0 AND "
      "v.__mg_id__ = 1 CREATE (u)-[:`REL`]->(v);");
  expect_row(
      "MATCH (u:__mg_vertex__), (v:__mg_vertex__) WHERE u.__mg_id__ = 0 AND "
      "v.__mg_id__ = 2 CREATE (u)-[:`REL`]->(v);");
  expect_row(
      "MATCH (u:__mg_vertex__), (v:__mg_vertex__) WHERE u.__mg_id__ = 1 AND "
      "v.__mg_id__ = 4 CREATE (u)-[:`REL`]->(v);");
  expect_row(
      "MATCH (u:__mg_vertex__), (v:__mg_vertex__) WHERE u.__mg_id__ = 3 AND "
      "v.__mg_id__ = 4 CREATE (u)-[:`REL`]->(v);");
  expect_row(kDropInternalIndex);
  expect_row(kRemoveInternalLabelProperty);
}

View File

@ -39,7 +39,7 @@ class QueryExecution : public testing::Test {
auto Execute(const std::string &query) {
ResultStreamFaker stream(&*db_);
auto [header, _] = interpreter_->Prepare(query, {});
auto [header, _, qid] = interpreter_->Prepare(query, {});
stream.Header(header);
auto summary = interpreter_->PullAll(&stream);
stream.Summary(summary);