From b15eeffd48783f4ef8d9e6a4fee09a61cb6c381c Mon Sep 17 00:00:00 2001 From: Teon Banek <teon.banek@memgraph.io> Date: Tue, 10 Jul 2018 16:18:19 +0200 Subject: [PATCH] Extract communication to static library Summary: Session specifics have been move out of the Bolt `executing` state, and are accessed via pure virtual Session type. Our server is templated on the session and we are setting the concrete type, so there should be no virtual call overhead. Abstract Session is used to indicate the interface, this could have also been templated, but the explicit interface definition makes it clearer. Specific session implementation for running Memgraph is now implemented in memgraph_bolt, which instantiates the concrete session type. This may not be 100% appropriate place, but Memgraph specific session isn't needed anywhere else. Bolt/communication tests now use a dummy session and depend only on communication, which significantly improves test run times. All these changes make the communication a library which doesn't depend on storage nor the database. Only shared connection points, which aren't part of the base communication library are: * glue/conversion -- which converts between storage and bolt types, and * communication/result_stream_faker -- templated, but used in tests and query/repl Depends on D1453 Reviewers: mferencevic, buda, mtomic, msantl Reviewed By: mferencevic, mtomic Subscribers: pullbot Differential Revision: https://phabricator.memgraph.io/D1456 --- src/CMakeLists.txt | 16 +- src/communication/CMakeLists.txt | 44 ++++ .../bolt/v1/encoder/result_stream.hpp | 13 +- src/communication/bolt/v1/session.hpp | 70 ++---- src/communication/bolt/v1/state.hpp | 14 +- src/communication/bolt/v1/states/error.hpp | 13 +- .../bolt/v1/states/executing.hpp | 128 +++-------- .../bolt/v1/states/handshake.hpp | 1 + src/communication/result_stream_faker.hpp | 5 +- src/database/state_delta.cpp | 10 +- src/durability/recovery.cpp | 12 +- src/durability/snapshooter.cpp | 2 +- src/durability/snapshot_encoder.hpp | 4 +- src/{communication => glue}/conversion.cpp | 6 +- src/{communication => glue}/conversion.hpp | 4 +- src/memgraph_bolt.cpp | 177 ++++++++++++++- src/query/interpreter.cpp | 2 +- src/query/interpreter.hpp | 7 +- src/storage/pod_buffer.hpp | 5 +- src/storage/property_value_store.cpp | 7 +- tests/manual/CMakeLists.txt | 6 +- tests/manual/distributed_common.hpp | 2 +- .../snapshot_generation/snapshot_writer.hpp | 6 +- tests/unit/CMakeLists.txt | 50 +++-- tests/unit/bolt_common.hpp | 3 +- tests/unit/bolt_decoder.cpp | 1 - tests/unit/bolt_encoder.cpp | 8 +- tests/unit/bolt_session.cpp | 206 ++++++------------ 28 files changed, 405 insertions(+), 417 deletions(-) create mode 100644 src/communication/CMakeLists.txt rename src/{communication => glue}/conversion.cpp (98%) rename src/{communication => glue}/conversion.hpp (93%) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 925191344..0a495230b 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -5,19 +5,10 @@ add_subdirectory(utils) add_subdirectory(integrations) add_subdirectory(io) add_subdirectory(telemetry) +add_subdirectory(communication) # all memgraph src files set(memgraph_src_files - communication/bolt/v1/decoder/decoded_value.cpp - communication/buffer.cpp - communication/client.cpp - communication/context.cpp - communication/conversion.cpp - communication/helpers.cpp - communication/init.cpp - communication/rpc/client.cpp - communication/rpc/protocol.cpp - communication/rpc/server.cpp data_structures/concurrent/skiplist_gc.cpp database/config.cpp database/counters.cpp @@ -49,6 +40,7 @@ set(memgraph_src_files durability/recovery.cpp durability/snapshooter.cpp durability/wal.cpp + glue/conversion.cpp query/common.cpp query/frontend/ast/ast.cpp query/frontend/ast/cypher_main_visitor.cpp @@ -195,10 +187,9 @@ string(TOLOWER ${CMAKE_BUILD_TYPE} lower_build_type) # memgraph_lib depend on these libraries set(MEMGRAPH_ALL_LIBS stdc++fs Threads::Threads fmt cppitertools antlr_opencypher_parser_lib dl glog gflags capnp kj - ${OPENSSL_LIBRARIES} ${Boost_IOSTREAMS_LIBRARY_RELEASE} ${Boost_SERIALIZATION_LIBRARY_RELEASE} - mg-utils mg-io mg-integrations) + mg-utils mg-io mg-integrations mg-communication) if (USE_LTALLOC) list(APPEND MEMGRAPH_ALL_LIBS ltalloc) @@ -213,7 +204,6 @@ endif() # STATIC library used by memgraph executables add_library(memgraph_lib STATIC ${memgraph_src_files}) target_link_libraries(memgraph_lib ${MEMGRAPH_ALL_LIBS}) -target_include_directories(memgraph_lib PRIVATE ${OPENSSL_INCLUDE_DIR}) add_dependencies(memgraph_lib generate_opencypher_parser) add_dependencies(memgraph_lib generate_lcp) add_dependencies(memgraph_lib generate_capnp) diff --git a/src/communication/CMakeLists.txt b/src/communication/CMakeLists.txt new file mode 100644 index 000000000..e41576c4b --- /dev/null +++ b/src/communication/CMakeLists.txt @@ -0,0 +1,44 @@ +set(communication_src_files + bolt/v1/decoder/decoded_value.cpp + buffer.cpp + client.cpp + context.cpp + helpers.cpp + init.cpp + rpc/client.cpp + rpc/protocol.cpp + rpc/server.cpp) + +# TODO: Extract data_structures to library +set(communication_src_files ${communication_src_files} + ${CMAKE_SOURCE_DIR}/src/data_structures/concurrent/skiplist_gc.cpp) + +# Use this function to add each capnp file to generation. This way each file is +# standalone and we avoid recompiling everything. +# NOTE: communication_src_files and communication_capnp_files are globally updated. +# TODO: This is duplicated from src/CMakeLists.txt and +# src/utils/CMakeLists.txt, find a good way to generalize this on per +# subdirectory basis. +function(add_capnp capnp_src_file) + set(cpp_file ${CMAKE_CURRENT_SOURCE_DIR}/${capnp_src_file}.c++) + set(h_file ${CMAKE_CURRENT_SOURCE_DIR}/${capnp_src_file}.h) + add_custom_command(OUTPUT ${cpp_file} ${h_file} + COMMAND ${CAPNP_EXE} compile -o${CAPNP_CXX_EXE} ${capnp_src_file} -I ${CMAKE_CURRENT_SOURCE_DIR} + DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/${capnp_src_file} capnproto-proj + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + # Update *global* communication_capnp_files + set(communication_capnp_files ${communication_capnp_files} ${cpp_file} ${h_file} PARENT_SCOPE) + # Update *global* communication_src_files + set(communication_src_files ${communication_src_files} ${cpp_file} PARENT_SCOPE) +endfunction(add_capnp) + +add_capnp(rpc/messages.capnp) + +add_custom_target(generate_communication_capnp DEPENDS ${communication_capnp_files}) + +add_library(mg-communication STATIC ${communication_src_files}) +target_link_libraries(mg-communication Threads::Threads mg-utils mg-io fmt glog gflags) +target_link_libraries(mg-communication ${OPENSSL_LIBRARIES}) +target_include_directories(mg-communication SYSTEM PUBLIC ${OPENSSL_INCLUDE_DIR}) +target_link_libraries(mg-communication capnp kj) +add_dependencies(mg-communication generate_communication_capnp) diff --git a/src/communication/bolt/v1/encoder/result_stream.hpp b/src/communication/bolt/v1/encoder/result_stream.hpp index e86b9be3e..4a87aa799 100644 --- a/src/communication/bolt/v1/encoder/result_stream.hpp +++ b/src/communication/bolt/v1/encoder/result_stream.hpp @@ -2,7 +2,6 @@ #include "communication/bolt/v1/encoder/chunked_encoder_buffer.hpp" #include "communication/bolt/v1/encoder/encoder.hpp" -#include "communication/conversion.hpp" namespace communication::bolt { @@ -47,20 +46,10 @@ class ResultStream { * * @param values the values that should be sent */ - void Result(std::vector<DecodedValue> &values) { + void Result(const std::vector<DecodedValue> &values) { encoder_.MessageRecord(values); } - // TODO: Move this to another class - void Result(std::vector<query::TypedValue> &values) { - std::vector<DecodedValue> decoded_values; - decoded_values.reserve(values.size()); - for (const auto &v : values) { - decoded_values.push_back(communication::ToDecodedValue(v)); - } - return Result(decoded_values); - } - /** * Writes a summary. Typically a summary is something like: * { diff --git a/src/communication/bolt/v1/session.hpp b/src/communication/bolt/v1/session.hpp index 300314b5f..34673897a 100644 --- a/src/communication/bolt/v1/session.hpp +++ b/src/communication/bolt/v1/session.hpp @@ -14,21 +14,10 @@ #include "communication/bolt/v1/states/executing.hpp" #include "communication/bolt/v1/states/handshake.hpp" #include "communication/bolt/v1/states/init.hpp" -#include "communication/buffer.hpp" -#include "database/graph_db.hpp" -#include "query/interpreter.hpp" -#include "transactions/transaction.hpp" #include "utils/exceptions.hpp" namespace communication::bolt { -/** Encapsulates Dbms and Interpreter that are passed through the network server - * and worker to the session. */ -struct SessionData { - database::MasterBase &db; - query::Interpreter interpreter{db}; -}; - /** * Bolt Session Exception * @@ -53,17 +42,28 @@ class Session { using ResultStreamT = ResultStream<Encoder<ChunkedEncoderBuffer<TOutputStream>>>; - Session(SessionData &data, TInputStream &input_stream, - TOutputStream &output_stream) - : db_(data.db), - interpreter_(data.interpreter), - input_stream_(input_stream), - output_stream_(output_stream) {} + Session(TInputStream &input_stream, TOutputStream &output_stream) + : input_stream_(input_stream), output_stream_(output_stream) {} - ~Session() { - if (db_accessor_) { - Abort(); - } + virtual ~Session() {} + + /** Return `true` if we are no longer running and accepting queries */ + virtual bool IsShuttingDown() = 0; + + /** + * Put results in the `result_stream` by processing the given `query` with + * `params`. + */ + virtual void PullAll(const std::string &query, + const std::map<std::string, DecodedValue> ¶ms, + ResultStreamT *result_stream) = 0; + + /** Aborts currently running query. */ + virtual void Abort() = 0; + + void PullAll(const std::string &query, + const std::map<std::string, DecodedValue> ¶ms) { + return PullAll(query, params, &result_stream_); } /** @@ -99,11 +99,9 @@ class Session { break; case State::Idle: case State::Result: - case State::WaitForRollback: state_ = StateExecutingRun(*this, state_); break; - case State::ErrorIdle: - case State::ErrorWaitForRollback: + case State::Error: state_ = StateErrorRun(*this, state_); break; default: @@ -122,28 +120,8 @@ class Session { } } - /** - * Commits associated transaction. - */ - void Commit() { - DCHECK(db_accessor_) << "Commit called and there is no transaction"; - db_accessor_->Commit(); - db_accessor_ = nullptr; - } - - /** - * Aborts associated transaction. - */ - void Abort() { - DCHECK(db_accessor_) << "Abort called and there is no transaction"; - db_accessor_->Abort(); - db_accessor_ = nullptr; - } - // TODO: Rethink if there is a way to hide some members. At the momement all // of them are public. - database::MasterBase &db_; - query::Interpreter &interpreter_; TInputStream &input_stream_; TOutputStream &output_stream_; @@ -156,9 +134,6 @@ class Session { bool handshake_done_{false}; State state_{State::Handshake}; - // GraphDbAccessor of active transaction in the session, can be null if - // there is no associated transaction. - std::unique_ptr<database::GraphDbAccessor> db_accessor_; private: void ClientFailureInvalidData() { @@ -176,4 +151,5 @@ class Session { throw SessionException("Something went wrong during session execution!"); } }; + } // namespace communication::bolt diff --git a/src/communication/bolt/v1/state.hpp b/src/communication/bolt/v1/state.hpp index b848ca6cd..b29b506f6 100644 --- a/src/communication/bolt/v1/state.hpp +++ b/src/communication/bolt/v1/state.hpp @@ -31,23 +31,11 @@ enum class State : uint8_t { */ Result, - /** - * There was an acked error in explicitly started transaction, now we are - * waiting for "ROLLBACK" in RUN command. - */ - WaitForRollback, - /** * This state handles errors, if client handles error response correctly next * state is Idle. */ - ErrorIdle, - - /** - * This state handles errors, if client handles error response correctly next - * state is WaitForRollback. - */ - ErrorWaitForRollback, + Error, /** * This is a 'virtual' state (it doesn't have a run function) which tells diff --git a/src/communication/bolt/v1/states/error.hpp b/src/communication/bolt/v1/states/error.hpp index e620991cc..d1c98f8ab 100644 --- a/src/communication/bolt/v1/states/error.hpp +++ b/src/communication/bolt/v1/states/error.hpp @@ -40,20 +40,13 @@ State StateErrorRun(TSession &session, State state) { return State::Close; } if (signature == Signature::Reset) { - if (session.db_accessor_) { - session.Abort(); - } + session.Abort(); return State::Idle; } // We got AckFailure get back to right state. - if (state == State::ErrorIdle) { - return State::Idle; - } else if (state == State::ErrorWaitForRollback) { - return State::WaitForRollback; - } else { - LOG(FATAL) << "Shouldn't happen"; - } + CHECK(state == State::Error) << "Shouldn't happen"; + return State::Idle; } else { uint8_t value = utils::UnderlyingCast(marker); diff --git a/src/communication/bolt/v1/states/executing.hpp b/src/communication/bolt/v1/states/executing.hpp index b687b923a..e7174f973 100644 --- a/src/communication/bolt/v1/states/executing.hpp +++ b/src/communication/bolt/v1/states/executing.hpp @@ -9,15 +9,23 @@ #include "communication/bolt/v1/codes.hpp" #include "communication/bolt/v1/decoder/decoded_value.hpp" #include "communication/bolt/v1/state.hpp" -#include "communication/conversion.hpp" -#include "database/graph_db.hpp" -#include "distributed/pull_rpc_clients.hpp" -#include "query/exceptions.hpp" -#include "query/typed_value.hpp" #include "utils/exceptions.hpp" namespace communication::bolt { +/** + * Used to indicate something is wrong with the client but the transaction is + * kept open for a potential retry. + * + * The most common use case for throwing this error is if something is wrong + * with the query. Perhaps a simple syntax error that can be fixed and query + * retried. + */ +class ClientError : public utils::BasicException { + public: + using utils::BasicException::BasicException; +}; + template <typename TSession> State HandleRun(TSession &session, State state, Marker marker) { const std::map<std::string, DecodedValue> kEmptyFields = { @@ -41,23 +49,6 @@ State HandleRun(TSession &session, State state, Marker marker) { return State::Close; } - if (state == State::WaitForRollback) { - if (query.ValueString() == "ROLLBACK") { - session.Abort(); - // One MessageSuccess for RUN command should be flushed. - session.encoder_.MessageSuccess(kEmptyFields); - // One for PULL_ALL should be chunked. - session.encoder_.MessageSuccess({}, false); - return State::Result; - } - DLOG(WARNING) << "Expected RUN \"ROLLBACK\" not received!"; - // Client could potentially recover if we move to error state, but we don't - // implement rollback of single command in transaction, only rollback of - // whole transaction so we can't continue in this transaction if we receive - // new RUN command. - return State::Close; - } - if (state != State::Idle) { // Client could potentially recover if we move to error state, but there is // no legitimate situation in which well working client would end up in this @@ -70,79 +61,20 @@ State HandleRun(TSession &session, State state, Marker marker) { << "There should be no data to write in this state"; DLOG(INFO) << fmt::format("[Run] '{}'", query.ValueString()); - bool in_explicit_transaction = false; - if (session.db_accessor_) { - // Transaction already exists. - in_explicit_transaction = true; - } else { - // TODO: Possible (but very unlikely) race condition, where we have alive - // session during shutdown, but is_accepting_transactions_ isn't yet false. - // We should probably create transactions under some locking mechanism. - if (!session.db_.is_accepting_transactions()) { - // Db is shutting down and doesn't accept new transactions so we should - // close this session. - return State::Close; - } - // Create new transaction. - session.db_accessor_ = - std::make_unique<database::GraphDbAccessor>(session.db_); - } - // If there was not explicitly started transaction before maybe we are - // starting one now. - if (!in_explicit_transaction && query.ValueString() == "BEGIN") { - // Check if query string is "BEGIN". If it is then we should start - // transaction and wait for in-transaction queries. - // TODO: "BEGIN" is not defined by bolt protocol or opencypher so we should - // test if all drivers really denote transaction start with "BEGIN" string. - // Same goes for "ROLLBACK" and "COMMIT". - // - // One MessageSuccess for RUN command should be flushed. - session.encoder_.MessageSuccess(kEmptyFields); - // One for PULL_ALL should be chunked. - session.encoder_.MessageSuccess({}, false); - return State::Result; + // TODO: Possible (but very unlikely) race condition, where we have alive + // session during shutdown, but IsAcceptingTransactions isn't yet false. + // We should probably create transactions under some locking mechanism. + if (session.IsShuttingDown()) { + // Db is shutting down and doesn't accept new transactions so we should + // close this session. + return State::Close; } try { - // This check is within try block because AdvanceCommand can throw. - if (in_explicit_transaction) { - if (query.ValueString() == "COMMIT") { - session.Commit(); - // One MessageSuccess for RUN command should be flushed. - session.encoder_.MessageSuccess(kEmptyFields); - // One for PULL_ALL should be chunked. - session.encoder_.MessageSuccess({}, false); - return State::Result; - } else if (query.ValueString() == "ROLLBACK") { - session.Abort(); - // One MessageSuccess for RUN command should be flushed. - session.encoder_.MessageSuccess(kEmptyFields); - // One for PULL_ALL should be chunked. - session.encoder_.MessageSuccess({}, false); - return State::Result; - } - session.db_accessor_->AdvanceCommand(); - if (session.db_.type() == database::GraphDb::Type::DISTRIBUTED_MASTER) { - auto tx_id = session.db_accessor_->transaction_id(); - auto futures = - session.db_.pull_clients().NotifyAllTransactionCommandAdvanced( - tx_id); - for (auto &future : futures) future.wait(); - } - } + // PullAll can throw. + session.PullAll(query.ValueString(), params.ValueMap()); - std::map<std::string, query::TypedValue> params_tv; - for (const auto &kv : params.ValueMap()) - params_tv.emplace(kv.first, communication::ToTypedValue(kv.second)); - session - .interpreter_(query.ValueString(), *session.db_accessor_, params_tv, - in_explicit_transaction) - .PullAll(session.result_stream_); - - if (!in_explicit_transaction) { - session.Commit(); - } // The query engine has already stored all query data in the buffer. // We should only send the first chunk now which is the success // message which contains header data. The rest of this data (records @@ -161,7 +93,7 @@ State HandleRun(TSession &session, State state, Marker marker) { } auto code_message = [&e]() -> std::pair<std::string, std::string> { - if (dynamic_cast<const query::QueryException *>(&e)) { + if (dynamic_cast<const ClientError *>(&e)) { // Clients expect 4 strings separated by dots. First being database name // (for example: Neo, Memgraph...), second being either ClientError, // TransientError or DatabaseError (or ClientNotification for warnings). @@ -177,10 +109,6 @@ State HandleRun(TSession &session, State state, Marker marker) { // receives *.TransientError.Transaction.Terminate it will not rerun // query even though TransientError was returned, because of Neo's // semantics of that error. - // - // QueryException was thrown, only changing the query or existing - // database data could make this query successful. Return ClientError to - // discourage retry of this query. return {"Memgraph.ClientError.MemgraphError.MemgraphError", e.what()}; } if (dynamic_cast<const utils::BasicException *>(&e)) { @@ -212,11 +140,7 @@ State HandleRun(TSession &session, State state, Marker marker) { DLOG(WARNING) << "Couldn't send failure message!"; return State::Close; } - if (!in_explicit_transaction) { - session.Abort(); - return State::ErrorIdle; - } - return State::ErrorWaitForRollback; + return State::Error; } } @@ -290,9 +214,7 @@ State HandleReset(Session &session, State, Marker marker) { DLOG(WARNING) << "Couldn't send success message!"; return State::Close; } - if (session.db_accessor_) { - session.Abort(); - } + session.Abort(); return State::Idle; } diff --git a/src/communication/bolt/v1/states/handshake.hpp b/src/communication/bolt/v1/states/handshake.hpp index 441d989dd..6267c56ee 100644 --- a/src/communication/bolt/v1/states/handshake.hpp +++ b/src/communication/bolt/v1/states/handshake.hpp @@ -5,6 +5,7 @@ #include "communication/bolt/v1/codes.hpp" #include "communication/bolt/v1/constants.hpp" #include "communication/bolt/v1/state.hpp" +#include "utils/likely.hpp" namespace communication::bolt { diff --git a/src/communication/result_stream_faker.hpp b/src/communication/result_stream_faker.hpp index 7a0a919b3..f52d8f3bf 100644 --- a/src/communication/result_stream_faker.hpp +++ b/src/communication/result_stream_faker.hpp @@ -37,8 +37,7 @@ class ResultStreamFaker { results_.push_back(values); } - void Summary( - const std::map<std::string, communication::bolt::DecodedValue> &summary) { + void Summary(const std::map<std::string, TResultValue> &summary) { DCHECK(current_state_ != State::Done) << "Can only send a summary once"; summary_ = summary; current_state_ = State::Done; @@ -136,5 +135,5 @@ class ResultStreamFaker { // the data that the record stream can accept std::vector<std::string> header_; std::vector<std::vector<TResultValue>> results_; - std::map<std::string, communication::bolt::DecodedValue> summary_; + std::map<std::string, TResultValue> summary_; }; diff --git a/src/database/state_delta.cpp b/src/database/state_delta.cpp index 04972e688..817017583 100644 --- a/src/database/state_delta.cpp +++ b/src/database/state_delta.cpp @@ -3,8 +3,8 @@ #include <string> #include "communication/bolt/v1/decoder/decoded_value.hpp" -#include "communication/conversion.hpp" #include "database/graph_db_accessor.hpp" +#include "glue/conversion.hpp" namespace database { @@ -208,13 +208,13 @@ void StateDelta::Encode( encoder.WriteInt(vertex_id); encoder.WriteInt(property.Id()); encoder.WriteString(property_name); - encoder.WriteDecodedValue(communication::ToDecodedValue(value)); + encoder.WriteDecodedValue(glue::ToDecodedValue(value)); break; case Type::SET_PROPERTY_EDGE: encoder.WriteInt(edge_id); encoder.WriteInt(property.Id()); encoder.WriteString(property_name); - encoder.WriteDecodedValue(communication::ToDecodedValue(value)); + encoder.WriteDecodedValue(glue::ToDecodedValue(value)); break; case Type::ADD_LABEL: case Type::REMOVE_LABEL: @@ -304,14 +304,14 @@ std::experimental::optional<StateDelta> StateDelta::Decode( DECODE_MEMBER_CAST(property, ValueInt, storage::Property) DECODE_MEMBER(property_name, ValueString) if (!decoder.ReadValue(&dv)) return nullopt; - r_val.value = communication::ToPropertyValue(dv); + r_val.value = glue::ToPropertyValue(dv); break; case Type::SET_PROPERTY_EDGE: DECODE_MEMBER(edge_id, ValueInt) DECODE_MEMBER_CAST(property, ValueInt, storage::Property) DECODE_MEMBER(property_name, ValueString) if (!decoder.ReadValue(&dv)) return nullopt; - r_val.value = communication::ToPropertyValue(dv); + r_val.value = glue::ToPropertyValue(dv); break; case Type::ADD_LABEL: case Type::REMOVE_LABEL: diff --git a/src/durability/recovery.cpp b/src/durability/recovery.cpp index ddab0b15e..a77de5228 100644 --- a/src/durability/recovery.cpp +++ b/src/durability/recovery.cpp @@ -4,7 +4,6 @@ #include <limits> #include <unordered_map> -#include "communication/conversion.hpp" #include "database/graph_db_accessor.hpp" #include "database/indexes/label_property_index.hpp" #include "durability/hashed_file_reader.hpp" @@ -13,6 +12,7 @@ #include "durability/snapshot_decoder.hpp" #include "durability/version.hpp" #include "durability/wal.hpp" +#include "glue/conversion.hpp" #include "query/typed_value.hpp" #include "storage/address_types.hpp" #include "transactions/type.hpp" @@ -127,15 +127,13 @@ bool RecoverSnapshot(const fs::path &snapshot_file, database::GraphDb &db, auto vertex = decoder.ReadSnapshotVertex(); RETURN_IF_NOT(vertex); - auto vertex_accessor = - dba.InsertVertex(vertex->gid, vertex->cypher_id); + auto vertex_accessor = dba.InsertVertex(vertex->gid, vertex->cypher_id); for (const auto &label : vertex->labels) { vertex_accessor.add_label(dba.Label(label)); } for (const auto &property_pair : vertex->properties) { - vertex_accessor.PropsSet( - dba.Property(property_pair.first), - communication::ToTypedValue(property_pair.second)); + vertex_accessor.PropsSet(dba.Property(property_pair.first), + glue::ToTypedValue(property_pair.second)); } auto vertex_record = vertex_accessor.GetNew(); for (const auto &edge : vertex->in) { @@ -205,7 +203,7 @@ bool RecoverSnapshot(const fs::path &snapshot_file, database::GraphDb &db, for (const auto &property_pair : edge.properties) edge_accessor.PropsSet(dba.Property(property_pair.first), - communication::ToTypedValue(property_pair.second)); + glue::ToTypedValue(property_pair.second)); } // Vertex and edge counts are included in the hash. Re-read them to update the diff --git a/src/durability/snapshooter.cpp b/src/durability/snapshooter.cpp index 8258ec8d8..af7624498 100644 --- a/src/durability/snapshooter.cpp +++ b/src/durability/snapshooter.cpp @@ -67,7 +67,7 @@ bool Encode(const fs::path &snapshot_file, database::GraphDb &db, vertex_num++; } for (const auto &edge : dba.Edges(false)) { - encoder.WriteEdge(communication::ToDecodedEdge(edge)); + encoder.WriteEdge(glue::ToDecodedEdge(edge)); encoder.WriteInt(edge.cypher_id()); edge_num++; } diff --git a/src/durability/snapshot_encoder.hpp b/src/durability/snapshot_encoder.hpp index 8edd284d3..9180c747f 100644 --- a/src/durability/snapshot_encoder.hpp +++ b/src/durability/snapshot_encoder.hpp @@ -1,8 +1,8 @@ #pragma once #include "communication/bolt/v1/encoder/base_encoder.hpp" -#include "communication/conversion.hpp" #include "database/graph_db_accessor.hpp" +#include "glue/conversion.hpp" #include "utils/cast.hpp" namespace durability { @@ -14,7 +14,7 @@ class SnapshotEncoder : public communication::bolt::BaseEncoder<Buffer> { : communication::bolt::BaseEncoder<Buffer>(buffer) {} void WriteSnapshotVertex(const VertexAccessor &vertex) { communication::bolt::BaseEncoder<Buffer>::WriteVertex( - communication::ToDecodedVertex(vertex)); + glue::ToDecodedVertex(vertex)); // Write cypher_id this->WriteInt(vertex.cypher_id()); diff --git a/src/communication/conversion.cpp b/src/glue/conversion.cpp similarity index 98% rename from src/communication/conversion.cpp rename to src/glue/conversion.cpp index 26211c2ba..1411f0574 100644 --- a/src/communication/conversion.cpp +++ b/src/glue/conversion.cpp @@ -1,4 +1,4 @@ -#include "communication/conversion.hpp" +#include "glue/conversion.hpp" #include <map> #include <string> @@ -8,7 +8,7 @@ using communication::bolt::DecodedValue; -namespace communication { +namespace glue { query::TypedValue ToTypedValue(const DecodedValue &value) { switch (value.type()) { @@ -189,4 +189,4 @@ DecodedValue ToDecodedValue(const PropertyValue &value) { } } -} // namespace communication +} // namespace glue diff --git a/src/communication/conversion.hpp b/src/glue/conversion.hpp similarity index 93% rename from src/communication/conversion.hpp rename to src/glue/conversion.hpp index 2c16facb1..9f2ff8ab1 100644 --- a/src/communication/conversion.hpp +++ b/src/glue/conversion.hpp @@ -5,7 +5,7 @@ #include "query/typed_value.hpp" #include "storage/property_value.hpp" -namespace communication { +namespace glue { communication::bolt::DecodedVertex ToDecodedVertex( const VertexAccessor &vertex); @@ -23,4 +23,4 @@ communication::bolt::DecodedValue ToDecodedValue(const PropertyValue &value); PropertyValue ToPropertyValue(const communication::bolt::DecodedValue &value); -} // namespace communication +} // namespace glue diff --git a/src/memgraph_bolt.cpp b/src/memgraph_bolt.cpp index 56b231c6a..af850020a 100644 --- a/src/memgraph_bolt.cpp +++ b/src/memgraph_bolt.cpp @@ -14,6 +14,10 @@ #include "communication/bolt/v1/session.hpp" #include "config.hpp" #include "database/graph_db.hpp" +#include "distributed/pull_rpc_clients.hpp" +#include "glue/conversion.hpp" +#include "query/exceptions.hpp" +#include "query/interpreter.hpp" #include "stats/stats.hpp" #include "telemetry/telemetry.hpp" #include "utils/flag_validation.hpp" @@ -24,12 +28,6 @@ // Common stuff for enterprise and community editions -using communication::bolt::SessionData; -using SessionT = communication::bolt::Session<communication::InputStream, - communication::OutputStream>; -using ServerT = communication::Server<SessionT, SessionData>; -using communication::ServerContext; - // General purpose flags. DEFINE_string(interface, "0.0.0.0", "Communication interface on which to listen."); @@ -59,6 +57,173 @@ DEFINE_bool(telemetry_enabled, false, "to allow for easier improvement of the product."); DECLARE_string(durability_directory); +/** Encapsulates Dbms and Interpreter that are passed through the network server + * and worker to the session. */ +struct SessionData { + database::MasterBase &db; + query::Interpreter interpreter{db}; +}; + +class BoltSession final + : public communication::bolt::Session<communication::InputStream, + communication::OutputStream> { + public: + BoltSession(SessionData &data, communication::InputStream &input_stream, + communication::OutputStream &output_stream) + : communication::bolt::Session<communication::InputStream, + communication::OutputStream>( + input_stream, output_stream), + db_(data.db), + interpreter_(data.interpreter) {} + + ~BoltSession() { + if (db_accessor_) { + Abort(); + } + } + + bool IsShuttingDown() override { return !db_.is_accepting_transactions(); } + + using communication::bolt::Session< + communication::InputStream, communication::OutputStream>::ResultStreamT; + + void PullAll( + const std::string &query, + const std::map<std::string, communication::bolt::DecodedValue> ¶ms, + ResultStreamT *result_stream) override { + bool in_explicit_transaction = !!db_accessor_; + if (!db_accessor_) + db_accessor_ = std::make_unique<database::GraphDbAccessor>(db_); + // TODO: Queries below should probably move to interpreter, but there is + // only one interpreter in GraphDb. We probably need some kind of + // per-session access for the interpreter. + // TODO: Also, write tests for these queries + if (expect_rollback_ && query != "ROLLBACK") { + // Client could potentially recover if we move to error state, but we + // don't implement rollback of single command in transaction, only + // rollback of whole transaction so we can't continue in this transaction + // if we receive new query. + throw communication::bolt::ClientError( + "Expected ROLLBACK, because previous query contained an error"); + } + if (query == "ROLLBACK") { + if (!in_explicit_transaction) + throw communication::bolt::ClientError( + "ROLLBACK can only be used after BEGIN"); + Abort(); + result_stream->Header({}); + result_stream->Summary({}); + return; + } else if (query == "BEGIN") { + if (in_explicit_transaction) + throw communication::bolt::ClientError("BEGIN already called"); + // We accept BEGIN, so send the empty results. + result_stream->Header({}); + result_stream->Summary({}); + return; + } else if (query == "COMMIT") { + if (!in_explicit_transaction) + throw communication::bolt::ClientError( + "COMMIT can only be used after BEGIN"); + Commit(); + result_stream->Header({}); + result_stream->Summary({}); + return; + } + // Any other query in BEGIN block advances the command. + if (in_explicit_transaction) AdvanceCommand(); + // Handle regular Cypher queries below + std::map<std::string, query::TypedValue> params_tv; + for (const auto &kv : params) + params_tv.emplace(kv.first, glue::ToTypedValue(kv.second)); + auto abort_tx = [this, in_explicit_transaction]() { + if (in_explicit_transaction) + expect_rollback_ = true; + else + this->Abort(); + }; + try { + TypedValueResultStream stream(result_stream); + interpreter_(query, *db_accessor_, params_tv, in_explicit_transaction) + .PullAll(stream); + if (!in_explicit_transaction) Commit(); + } catch (const query::QueryException &e) { + abort_tx(); + // Wrap QueryException into ClientError, because we want to allow the + // client to fix their query. + throw communication::bolt::ClientError(e.what()); + } catch (const utils::BasicException &) { + abort_tx(); + throw; + } + } + + void Abort() override { + if (!db_accessor_) return; + db_accessor_->Abort(); + db_accessor_ = nullptr; + } + + private: + // Wrapper around ResultStreamT which converts TypedValue to DecodedValue + // before forwarding the calls to original ResultStreamT. + class TypedValueResultStream { + public: + TypedValueResultStream(ResultStreamT *result_stream) + : result_stream_(result_stream) {} + + void Header(const std::vector<std::string> &fields) { + return result_stream_->Header(fields); + } + + void Result(const std::vector<query::TypedValue> &values) { + std::vector<communication::bolt::DecodedValue> decoded_values; + decoded_values.reserve(values.size()); + for (const auto &v : values) { + decoded_values.push_back(glue::ToDecodedValue(v)); + } + return result_stream_->Result(decoded_values); + } + + void Summary(const std::map<std::string, query::TypedValue> &summary) { + std::map<std::string, communication::bolt::DecodedValue> decoded_summary; + for (const auto &kv : summary) { + decoded_summary.emplace(kv.first, glue::ToDecodedValue(kv.second)); + } + return result_stream_->Summary(decoded_summary); + } + + private: + ResultStreamT *result_stream_; + }; + + database::MasterBase &db_; + query::Interpreter &interpreter_; + // GraphDbAccessor of active transaction in the session, can be null if + // there is no associated transaction. + std::unique_ptr<database::GraphDbAccessor> db_accessor_; + bool expect_rollback_{false}; + + void Commit() { + DCHECK(db_accessor_) << "Commit called and there is no transaction"; + db_accessor_->Commit(); + db_accessor_ = nullptr; + } + + void AdvanceCommand() { + db_accessor_->AdvanceCommand(); + if (db_.type() == database::GraphDb::Type::DISTRIBUTED_MASTER) { + auto tx_id = db_accessor_->transaction_id(); + auto futures = + db_.pull_clients().NotifyAllTransactionCommandAdvanced(tx_id); + for (auto &future : futures) future.wait(); + } + } +}; + +using ServerT = communication::Server<BoltSession, SessionData>; +using communication::ServerContext; + // Needed to correctly handle memgraph destruction from a signal handler. // Without having some sort of a flag, it is possible that a signal is handled // when we are exiting main, inside destructors of database::GraphDb and diff --git a/src/query/interpreter.cpp b/src/query/interpreter.cpp index 74a4be1fc..5c5df863b 100644 --- a/src/query/interpreter.cpp +++ b/src/query/interpreter.cpp @@ -98,7 +98,7 @@ Interpreter::Results Interpreter::operator()( ctx.symbol_table_ = plan->symbol_table(); - std::map<std::string, communication::bolt::DecodedValue> summary; + std::map<std::string, TypedValue> summary; summary["parsing_time"] = frontend_time.count(); summary["planning_time"] = planning_time.count(); summary["cost_estimate"] = plan->cost(); diff --git a/src/query/interpreter.hpp b/src/query/interpreter.hpp index db90fe1af..fae3b95bb 100644 --- a/src/query/interpreter.hpp +++ b/src/query/interpreter.hpp @@ -2,8 +2,6 @@ #include <gflags/gflags.h> -#include "communication/bolt/v1/encoder/base_encoder.hpp" -#include "communication/conversion.hpp" #include "data_structures/concurrent/concurrent_map.hpp" #include "database/graph_db.hpp" #include "database/graph_db_accessor.hpp" @@ -68,8 +66,7 @@ class Interpreter { Results(Context ctx, std::shared_ptr<CachedPlan> plan, std::unique_ptr<query::plan::Cursor> cursor, std::vector<Symbol> output_symbols, std::vector<std::string> header, - std::map<std::string, communication::bolt::DecodedValue> summary, - PlanCacheT &plan_cache) + std::map<std::string, TypedValue> summary, PlanCacheT &plan_cache) : ctx_(std::move(ctx)), plan_(plan), cursor_(std::move(cursor)), @@ -144,7 +141,7 @@ class Interpreter { bool header_written_{false}; std::vector<std::string> header_; - std::map<std::string, communication::bolt::DecodedValue> summary_; + std::map<std::string, TypedValue> summary_; utils::Timer execution_timer_; // Gets invalidated after if an index has been built. diff --git a/src/storage/pod_buffer.hpp b/src/storage/pod_buffer.hpp index 45c17203e..0326331e4 100644 --- a/src/storage/pod_buffer.hpp +++ b/src/storage/pod_buffer.hpp @@ -1,6 +1,9 @@ #pragma once -#include "communication/bolt/v1/encoder/base_encoder.hpp" +#include <cstdint> +#include <cstring> +#include <string> +#include <vector> namespace storage { diff --git a/src/storage/property_value_store.cpp b/src/storage/property_value_store.cpp index 4f6d5d70e..1c5720db1 100644 --- a/src/storage/property_value_store.cpp +++ b/src/storage/property_value_store.cpp @@ -4,7 +4,8 @@ #include "glog/logging.h" #include "communication/bolt/v1/decoder/decoder.hpp" -#include "communication/conversion.hpp" +#include "communication/bolt/v1/encoder/base_encoder.hpp" +#include "glue/conversion.hpp" #include "storage/pod_buffer.hpp" #include "storage/property_value_store.hpp" @@ -213,7 +214,7 @@ PropertyValueStore::iterator PropertyValueStore::end() const { std::string PropertyValueStore::SerializeProp(const PropertyValue &prop) const { storage::PODBuffer pod_buffer; BaseEncoder<storage::PODBuffer> encoder{pod_buffer}; - encoder.WriteDecodedValue(communication::ToDecodedValue(prop)); + encoder.WriteDecodedValue(glue::ToDecodedValue(prop)); return std::string(reinterpret_cast<char *>(pod_buffer.buffer.data()), pod_buffer.buffer.size()); } @@ -228,7 +229,7 @@ PropertyValue PropertyValueStore::DeserializeProp( DLOG(WARNING) << "Unable to read property value"; return PropertyValue::Null; } - return communication::ToPropertyValue(dv); + return glue::ToPropertyValue(dv); } storage::KVStore PropertyValueStore::ConstructDiskStorage() const { diff --git a/tests/manual/CMakeLists.txt b/tests/manual/CMakeLists.txt index aab46eb77..b53e414b3 100644 --- a/tests/manual/CMakeLists.txt +++ b/tests/manual/CMakeLists.txt @@ -28,7 +28,7 @@ add_manual_test(binomial.cpp) target_link_libraries(${test_prefix}binomial mg-utils) add_manual_test(bolt_client.cpp) -target_link_libraries(${test_prefix}bolt_client memgraph_lib kvstore_dummy_lib) +target_link_libraries(${test_prefix}bolt_client mg-communication) add_manual_test(card_fraud_generate_snapshot.cpp) target_link_libraries(${test_prefix}card_fraud_generate_snapshot memgraph_lib kvstore_dummy_lib) @@ -72,10 +72,10 @@ add_manual_test(stripped_timing.cpp) target_link_libraries(${test_prefix}stripped_timing memgraph_lib kvstore_dummy_lib) add_manual_test(ssl_client.cpp) -target_link_libraries(${test_prefix}ssl_client memgraph_lib kvstore_dummy_lib) +target_link_libraries(${test_prefix}ssl_client mg-communication) add_manual_test(ssl_server.cpp) -target_link_libraries(${test_prefix}ssl_server memgraph_lib kvstore_dummy_lib) +target_link_libraries(${test_prefix}ssl_server mg-communication) add_manual_test(xorshift.cpp) target_link_libraries(${test_prefix}xorshift mg-utils) diff --git a/tests/manual/distributed_common.hpp b/tests/manual/distributed_common.hpp index 0ac0d1729..676d90d90 100644 --- a/tests/manual/distributed_common.hpp +++ b/tests/manual/distributed_common.hpp @@ -3,9 +3,9 @@ #include <chrono> #include <vector> -#include "communication/conversion.hpp" #include "communication/result_stream_faker.hpp" #include "database/graph_db_accessor.hpp" +#include "glue/conversion.hpp" #include "query/interpreter.hpp" #include "query/typed_value.hpp" diff --git a/tests/manual/snapshot_generation/snapshot_writer.hpp b/tests/manual/snapshot_generation/snapshot_writer.hpp index df8881ecd..26e190ef0 100644 --- a/tests/manual/snapshot_generation/snapshot_writer.hpp +++ b/tests/manual/snapshot_generation/snapshot_writer.hpp @@ -4,10 +4,10 @@ #include <vector> #include "communication/bolt/v1/encoder/base_encoder.hpp" -#include "communication/conversion.hpp" #include "durability/hashed_file_writer.hpp" #include "durability/paths.hpp" #include "durability/version.hpp" +#include "glue/conversion.hpp" #include "query/typed_value.hpp" #include "utils/file.hpp" @@ -71,7 +71,7 @@ class SnapshotWriter { WriteList(node.labels); std::map<std::string, communication::bolt::DecodedValue> props; for (const auto &prop : node.props) { - props[prop.first] = communication::ToDecodedValue(prop.second); + props[prop.first] = glue::ToDecodedValue(prop.second); } encoder_.WriteMap(props); @@ -102,7 +102,7 @@ class SnapshotWriter { encoder_.WriteString(edge.type); std::map<std::string, communication::bolt::DecodedValue> props; for (const auto &prop : edge.props) { - props[prop.first] = communication::ToDecodedValue(prop.second); + props[prop.first] = glue::ToDecodedValue(prop.second); } encoder_.WriteMap(props); diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index 466de9102..f2fef0b19 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -19,27 +19,9 @@ function(add_unit_test test_cpp) add_dependencies(memgraph__unit ${target_name}) endfunction(add_unit_test) -add_unit_test(bolt_chunked_decoder_buffer.cpp) -target_link_libraries(${test_prefix}bolt_chunked_decoder_buffer memgraph_lib kvstore_dummy_lib) - -add_unit_test(bolt_chunked_encoder_buffer.cpp) -target_link_libraries(${test_prefix}bolt_chunked_encoder_buffer memgraph_lib kvstore_dummy_lib) - -add_unit_test(bolt_decoder.cpp) -target_link_libraries(${test_prefix}bolt_decoder memgraph_lib kvstore_dummy_lib) - add_unit_test(bolt_encoder.cpp) target_link_libraries(${test_prefix}bolt_encoder memgraph_lib kvstore_dummy_lib) -add_unit_test(bolt_result_stream.cpp) -target_link_libraries(${test_prefix}bolt_result_stream memgraph_lib kvstore_dummy_lib) - -add_unit_test(bolt_session.cpp) -target_link_libraries(${test_prefix}bolt_session memgraph_lib kvstore_dummy_lib) - -add_unit_test(communication_buffer.cpp) -target_link_libraries(${test_prefix}communication_buffer memgraph_lib kvstore_dummy_lib) - add_unit_test(concurrent_id_mapper_distributed.cpp) target_link_libraries(${test_prefix}concurrent_id_mapper_distributed memgraph_lib kvstore_dummy_lib) @@ -157,9 +139,6 @@ target_link_libraries(${test_prefix}mvcc_one_transaction memgraph_lib kvstore_du add_unit_test(mvcc_parallel_update.cpp) target_link_libraries(${test_prefix}mvcc_parallel_update memgraph_lib kvstore_dummy_lib) -add_unit_test(network_timeouts.cpp) -target_link_libraries(${test_prefix}network_timeouts memgraph_lib kvstore_dummy_lib) - add_unit_test(pod_buffer.cpp) target_link_libraries(${test_prefix}pod_buffer memgraph_lib kvstore_dummy_lib) @@ -208,9 +187,6 @@ target_link_libraries(${test_prefix}raft_storage memgraph_lib kvstore_dummy_lib) add_unit_test(record_edge_vertex_accessor.cpp) target_link_libraries(${test_prefix}record_edge_vertex_accessor memgraph_lib kvstore_dummy_lib) -add_unit_test(rpc.cpp) -target_link_libraries(${test_prefix}rpc memgraph_lib kvstore_dummy_lib) - add_unit_test(rpc_worker_clients.cpp) target_link_libraries(${test_prefix}rpc_worker_clients memgraph_lib kvstore_dummy_lib) @@ -253,6 +229,32 @@ target_link_libraries(${test_prefix}transaction_engine_single_node memgraph_lib add_unit_test(typed_value.cpp) target_link_libraries(${test_prefix}typed_value memgraph_lib kvstore_dummy_lib) +# Test mg-communication + +add_unit_test(bolt_chunked_decoder_buffer.cpp) +target_link_libraries(${test_prefix}bolt_chunked_decoder_buffer mg-communication) + +add_unit_test(bolt_chunked_encoder_buffer.cpp) +target_link_libraries(${test_prefix}bolt_chunked_encoder_buffer mg-communication) + +add_unit_test(bolt_decoder.cpp) +target_link_libraries(${test_prefix}bolt_decoder mg-communication) + +add_unit_test(bolt_result_stream.cpp) +target_link_libraries(${test_prefix}bolt_result_stream mg-communication) + +add_unit_test(bolt_session.cpp) +target_link_libraries(${test_prefix}bolt_session mg-communication) + +add_unit_test(communication_buffer.cpp) +target_link_libraries(${test_prefix}communication_buffer mg-communication) + +add_unit_test(network_timeouts.cpp) +target_link_libraries(${test_prefix}network_timeouts mg-communication) + +add_unit_test(rpc.cpp) +target_link_libraries(${test_prefix}rpc mg-communication) + # Test data structures add_unit_test(ring_buffer.cpp) diff --git a/tests/unit/bolt_common.hpp b/tests/unit/bolt_common.hpp index 199c5b1f0..47dd75b3a 100644 --- a/tests/unit/bolt_common.hpp +++ b/tests/unit/bolt_common.hpp @@ -6,8 +6,7 @@ #include <vector> #include <glog/logging.h> - -#include "gtest/gtest.h" +#include <gtest/gtest.h> /** * TODO (mferencevic): document diff --git a/tests/unit/bolt_decoder.cpp b/tests/unit/bolt_decoder.cpp index ea3b3d756..64d062673 100644 --- a/tests/unit/bolt_decoder.cpp +++ b/tests/unit/bolt_decoder.cpp @@ -4,7 +4,6 @@ #include "bolt_testdata.hpp" #include "communication/bolt/v1/decoder/decoder.hpp" -#include "query/typed_value.hpp" using communication::bolt::DecodedValue; diff --git a/tests/unit/bolt_encoder.cpp b/tests/unit/bolt_encoder.cpp index 598ec6a04..ee687cf2e 100644 --- a/tests/unit/bolt_encoder.cpp +++ b/tests/unit/bolt_encoder.cpp @@ -2,9 +2,9 @@ #include "bolt_testdata.hpp" #include "communication/bolt/v1/encoder/encoder.hpp" -#include "communication/conversion.hpp" #include "database/graph_db.hpp" #include "database/graph_db_accessor.hpp" +#include "glue/conversion.hpp" using communication::bolt::DecodedValue; @@ -189,9 +189,9 @@ TEST(BoltEncoder, VertexAndEdge) { // check everything std::vector<DecodedValue> vals; - vals.push_back(communication::ToDecodedValue(va1)); - vals.push_back(communication::ToDecodedValue(va2)); - vals.push_back(communication::ToDecodedValue(ea)); + vals.push_back(glue::ToDecodedValue(va1)); + vals.push_back(glue::ToDecodedValue(va2)); + vals.push_back(glue::ToDecodedValue(ea)); bolt_encoder.MessageRecord(vals); // The vertexedge_encoded testdata has hardcoded zeros for IDs, diff --git a/tests/unit/bolt_session.cpp b/tests/unit/bolt_session.cpp index 5cbb8dea5..f8eeb2f9c 100644 --- a/tests/unit/bolt_session.cpp +++ b/tests/unit/bolt_session.cpp @@ -2,27 +2,60 @@ #include <glog/logging.h> #include "bolt_common.hpp" -#include "communication/bolt/v1/encoder/result_stream.hpp" #include "communication/bolt/v1/session.hpp" -#include "database/graph_db.hpp" + +using communication::bolt::ClientError; +using communication::bolt::DecodedValue; +using communication::bolt::Session; +using communication::bolt::SessionException; +using communication::bolt::State; + +static const char *kInvalidQuery = "invalid query"; +static const char *kQueryReturn42 = "RETURN 42"; +static const char *kQueryEmpty = "no results"; + +class TestSessionData {}; + +class TestSession : public Session<TestInputStream, TestOutputStream> { + public: + using Session<TestInputStream, TestOutputStream>::ResultStreamT; + + TestSession(TestSessionData &data, TestInputStream &input_stream, + TestOutputStream &output_stream) + : Session<TestInputStream, TestOutputStream>(input_stream, + output_stream) {} + + bool IsShuttingDown() override { return false; } + + void PullAll(const std::string &query, + const std::map<std::string, DecodedValue> ¶ms, + ResultStreamT *result_stream) override { + if (query == kQueryReturn42) { + result_stream->Header({"result_name"}); + result_stream->Result(std::vector<DecodedValue>{42}); + result_stream->Summary({}); + } else if (query == kQueryEmpty) { + result_stream->Header({"result_name"}); + result_stream->Summary({}); + } else { + throw ClientError("client sent invalid query"); + } + } + + void Abort() override {} +}; + +using ResultStreamT = TestSession::ResultStreamT; // TODO: This could be done in fixture. // Shortcuts for writing variable initializations in tests -#define INIT_VARS \ - TestInputStream input_stream; \ - TestOutputStream output_stream; \ - database::SingleNode db; \ - SessionData session_data{db}; \ - SessionT session(session_data, input_stream, output_stream); \ +#define INIT_VARS \ + TestInputStream input_stream; \ + TestOutputStream output_stream; \ + TestSessionData session_data; \ + TestSession session(session_data, input_stream, output_stream); \ std::vector<uint8_t> &output = output_stream.output; -using communication::bolt::SessionData; -using communication::bolt::SessionException; -using communication::bolt::State; -using SessionT = - communication::bolt::Session<TestInputStream, TestOutputStream>; -using ResultStreamT = SessionT::ResultStreamT; - // Sample testdata that has correct inputs and outputs. const uint8_t handshake_req[] = {0x60, 0x60, 0xb0, 0x17, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, @@ -85,7 +118,7 @@ void CheckIgnoreMessage(std::vector<uint8_t> &output) { } // Execute and check a correct handshake -void ExecuteHandshake(TestInputStream &input_stream, SessionT &session, +void ExecuteHandshake(TestInputStream &input_stream, TestSession &session, std::vector<uint8_t> &output) { input_stream.Write(handshake_req, 20); session.Execute(); @@ -95,7 +128,7 @@ void ExecuteHandshake(TestInputStream &input_stream, SessionT &session, } // Write bolt chunk and execute command -void ExecuteCommand(TestInputStream &input_stream, SessionT &session, +void ExecuteCommand(TestInputStream &input_stream, TestSession &session, const uint8_t *data, size_t len, bool chunk = true) { if (chunk) WriteChunkHeader(input_stream, len); input_stream.Write(data, len); @@ -104,7 +137,7 @@ void ExecuteCommand(TestInputStream &input_stream, SessionT &session, } // Execute and check a correct init -void ExecuteInit(TestInputStream &input_stream, SessionT &session, +void ExecuteInit(TestInputStream &input_stream, TestSession &session, std::vector<uint8_t> &output) { ExecuteCommand(input_stream, session, init_req, sizeof(init_req)); ASSERT_EQ(session.state_, State::Idle); @@ -277,7 +310,7 @@ TEST(BoltSession, ExecuteRunBasicException) { ExecuteInit(input_stream, session, output); output_stream.SetWriteSuccess(i == 0); - WriteRunRequest(input_stream, "MATCH (omnom"); + WriteRunRequest(input_stream, kInvalidQuery); if (i == 0) { session.Execute(); } else { @@ -285,7 +318,7 @@ TEST(BoltSession, ExecuteRunBasicException) { } if (i == 0) { - ASSERT_EQ(session.state_, State::ErrorIdle); + ASSERT_EQ(session.state_, State::Error); CheckFailureMessage(output); } else { ASSERT_EQ(session.state_, State::Close); @@ -300,7 +333,7 @@ TEST(BoltSession, ExecuteRunWithoutPullAll) { ExecuteHandshake(input_stream, session, output); ExecuteInit(input_stream, session, output); - WriteRunRequest(input_stream, "RETURN 2"); + WriteRunRequest(input_stream, kQueryReturn42); session.Execute(); ASSERT_EQ(session.state_, State::Result); @@ -361,7 +394,7 @@ TEST(BoltSession, ExecutePullAllDiscardAllReset) { ExecuteHandshake(input_stream, session, output); ExecuteInit(input_stream, session, output); - WriteRunRequest(input_stream, "CREATE (n) RETURN n"); + WriteRunRequest(input_stream, kQueryReturn42); session.Execute(); if (j == 1) output.clear(); @@ -407,7 +440,7 @@ TEST(BoltSession, ErrorIgnoreMessage) { ExecuteHandshake(input_stream, session, output); ExecuteInit(input_stream, session, output); - WriteRunRequest(input_stream, "MATCH (omnom"); + WriteRunRequest(input_stream, kInvalidQuery); session.Execute(); output.clear(); @@ -425,7 +458,7 @@ TEST(BoltSession, ErrorIgnoreMessage) { ASSERT_EQ(session.decoder_buffer_.Size(), 0); if (i == 0) { - ASSERT_EQ(session.state_, State::ErrorIdle); + ASSERT_EQ(session.state_, State::Error); CheckOutput(output, ignored_resp, sizeof(ignored_resp)); } else { ASSERT_EQ(session.state_, State::Close); @@ -441,7 +474,7 @@ TEST(BoltSession, ErrorRunAfterRun) { ExecuteHandshake(input_stream, session, output); ExecuteInit(input_stream, session, output); - WriteRunRequest(input_stream, "MATCH (n) RETURN n"); + WriteRunRequest(input_stream, kQueryReturn42); session.Execute(); output.clear(); @@ -452,7 +485,7 @@ TEST(BoltSession, ErrorRunAfterRun) { ASSERT_EQ(session.state_, State::Result); // New run request. - WriteRunRequest(input_stream, "MATCH (n) RETURN n"); + WriteRunRequest(input_stream, kQueryReturn42); ASSERT_THROW(session.Execute(), SessionException); ASSERT_EQ(session.state_, State::Close); @@ -464,7 +497,7 @@ TEST(BoltSession, ErrorCantCleanup) { ExecuteHandshake(input_stream, session, output); ExecuteInit(input_stream, session, output); - WriteRunRequest(input_stream, "MATCH (omnom"); + WriteRunRequest(input_stream, kInvalidQuery); session.Execute(); output.clear(); @@ -484,7 +517,7 @@ TEST(BoltSession, ErrorWrongMarker) { ExecuteHandshake(input_stream, session, output); ExecuteInit(input_stream, session, output); - WriteRunRequest(input_stream, "MATCH (omnom"); + WriteRunRequest(input_stream, kInvalidQuery); session.Execute(); output.clear(); @@ -510,7 +543,7 @@ TEST(BoltSession, ErrorOK) { ExecuteHandshake(input_stream, session, output); ExecuteInit(input_stream, session, output); - WriteRunRequest(input_stream, "MATCH (omnom"); + WriteRunRequest(input_stream, kInvalidQuery); session.Execute(); output.clear(); @@ -543,7 +576,7 @@ TEST(BoltSession, ErrorMissingData) { ExecuteHandshake(input_stream, session, output); ExecuteInit(input_stream, session, output); - WriteRunRequest(input_stream, "MATCH (omnom"); + WriteRunRequest(input_stream, kInvalidQuery); session.Execute(); output.clear(); @@ -563,7 +596,7 @@ TEST(BoltSession, MultipleChunksInOneExecute) { ExecuteHandshake(input_stream, session, output); ExecuteInit(input_stream, session, output); - WriteRunRequest(input_stream, "CREATE (n) RETURN n"); + WriteRunRequest(input_stream, kQueryReturn42); ExecuteCommand(input_stream, session, pullall_req, sizeof(pullall_req)); ASSERT_EQ(session.state_, State::Idle); @@ -607,117 +640,6 @@ TEST(BoltSession, PartialChunk) { PrintOutput(output); } -TEST(BoltSession, ExplicitTransactionValidQueries) { - // It is not really easy to check if we commited or aborted transaction except - // by faking GraphDb/TxEngine... - std::vector<std::string> transaction_ends = {"COMMIT", "ROLLBACK"}; - - for (const auto &transaction_end : transaction_ends) { - INIT_VARS; - - ExecuteHandshake(input_stream, session, output); - ExecuteInit(input_stream, session, output); - - WriteRunRequest(input_stream, "BEGIN"); - session.Execute(); - ASSERT_EQ(session.state_, State::Result); - ASSERT_TRUE(session.db_accessor_); - CheckSuccessMessage(output); - - ExecuteCommand(input_stream, session, pullall_req, sizeof(pullall_req)); - session.Execute(); - ASSERT_EQ(session.state_, State::Idle); - ASSERT_TRUE(session.db_accessor_); - CheckSuccessMessage(output); - - WriteRunRequest(input_stream, "MATCH (n) RETURN n"); - session.Execute(); - ASSERT_EQ(session.state_, State::Result); - ASSERT_TRUE(session.db_accessor_); - CheckSuccessMessage(output); - - ExecuteCommand(input_stream, session, pullall_req, sizeof(pullall_req)); - session.Execute(); - ASSERT_EQ(session.state_, State::Idle); - ASSERT_TRUE(session.db_accessor_); - CheckSuccessMessage(output); - - WriteRunRequest(input_stream, transaction_end.c_str()); - session.Execute(); - ASSERT_FALSE(session.db_accessor_); - CheckSuccessMessage(output); - ASSERT_EQ(session.state_, State::Result); - - ExecuteCommand(input_stream, session, pullall_req, sizeof(pullall_req)); - session.Execute(); - ASSERT_EQ(session.state_, State::Idle); - ASSERT_FALSE(session.db_accessor_); - CheckSuccessMessage(output); - } -} - -TEST(BoltSession, ExplicitTransactionInvalidQuery) { - std::vector<std::string> transaction_ends = {"COMMIT", "ROLLBACK"}; - - for (const auto &transaction_end : transaction_ends) { - INIT_VARS; - - ExecuteHandshake(input_stream, session, output); - ExecuteInit(input_stream, session, output); - - WriteRunRequest(input_stream, "BEGIN"); - session.Execute(); - ASSERT_EQ(session.state_, State::Result); - ASSERT_TRUE(session.db_accessor_); - CheckSuccessMessage(output); - - ExecuteCommand(input_stream, session, pullall_req, sizeof(pullall_req)); - session.Execute(); - ASSERT_EQ(session.state_, State::Idle); - ASSERT_TRUE(session.db_accessor_); - CheckSuccessMessage(output); - - WriteRunRequest(input_stream, "MATCH ("); - session.Execute(); - ASSERT_EQ(session.state_, State::ErrorWaitForRollback); - ASSERT_TRUE(session.db_accessor_); - CheckFailureMessage(output); - - ExecuteCommand(input_stream, session, pullall_req, sizeof(pullall_req)); - session.Execute(); - ASSERT_EQ(session.state_, State::ErrorWaitForRollback); - ASSERT_TRUE(session.db_accessor_); - CheckIgnoreMessage(output); - - ExecuteCommand(input_stream, session, ackfailure_req, - sizeof(ackfailure_req)); - session.Execute(); - ASSERT_EQ(session.state_, State::WaitForRollback); - ASSERT_TRUE(session.db_accessor_); - CheckSuccessMessage(output); - - WriteRunRequest(input_stream, transaction_end.c_str()); - - if (transaction_end == "ROLLBACK") { - session.Execute(); - ASSERT_EQ(session.state_, State::Result); - ASSERT_FALSE(session.db_accessor_); - CheckSuccessMessage(output); - - ExecuteCommand(input_stream, session, pullall_req, sizeof(pullall_req)); - session.Execute(); - ASSERT_EQ(session.state_, State::Idle); - ASSERT_FALSE(session.db_accessor_); - CheckSuccessMessage(output); - - } else { - ASSERT_THROW(session.Execute(), SessionException); - ASSERT_EQ(session.state_, State::Close); - CheckFailureMessage(output); - } - } -} - int main(int argc, char **argv) { google::InitGoogleLogging(argv[0]); ::testing::InitGoogleTest(&argc, argv);