diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 925191344..0a495230b 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -5,19 +5,10 @@ add_subdirectory(utils) add_subdirectory(integrations) add_subdirectory(io) add_subdirectory(telemetry) +add_subdirectory(communication) # all memgraph src files set(memgraph_src_files - communication/bolt/v1/decoder/decoded_value.cpp - communication/buffer.cpp - communication/client.cpp - communication/context.cpp - communication/conversion.cpp - communication/helpers.cpp - communication/init.cpp - communication/rpc/client.cpp - communication/rpc/protocol.cpp - communication/rpc/server.cpp data_structures/concurrent/skiplist_gc.cpp database/config.cpp database/counters.cpp @@ -49,6 +40,7 @@ set(memgraph_src_files durability/recovery.cpp durability/snapshooter.cpp durability/wal.cpp + glue/conversion.cpp query/common.cpp query/frontend/ast/ast.cpp query/frontend/ast/cypher_main_visitor.cpp @@ -195,10 +187,9 @@ string(TOLOWER ${CMAKE_BUILD_TYPE} lower_build_type) # memgraph_lib depend on these libraries set(MEMGRAPH_ALL_LIBS stdc++fs Threads::Threads fmt cppitertools antlr_opencypher_parser_lib dl glog gflags capnp kj - ${OPENSSL_LIBRARIES} ${Boost_IOSTREAMS_LIBRARY_RELEASE} ${Boost_SERIALIZATION_LIBRARY_RELEASE} - mg-utils mg-io mg-integrations) + mg-utils mg-io mg-integrations mg-communication) if (USE_LTALLOC) list(APPEND MEMGRAPH_ALL_LIBS ltalloc) @@ -213,7 +204,6 @@ endif() # STATIC library used by memgraph executables add_library(memgraph_lib STATIC ${memgraph_src_files}) target_link_libraries(memgraph_lib ${MEMGRAPH_ALL_LIBS}) -target_include_directories(memgraph_lib PRIVATE ${OPENSSL_INCLUDE_DIR}) add_dependencies(memgraph_lib generate_opencypher_parser) add_dependencies(memgraph_lib generate_lcp) add_dependencies(memgraph_lib generate_capnp) diff --git a/src/communication/CMakeLists.txt b/src/communication/CMakeLists.txt new file mode 100644 index 000000000..e41576c4b --- /dev/null +++ b/src/communication/CMakeLists.txt @@ -0,0 +1,44 @@ +set(communication_src_files + bolt/v1/decoder/decoded_value.cpp + buffer.cpp + client.cpp + context.cpp + helpers.cpp + init.cpp + rpc/client.cpp + rpc/protocol.cpp + rpc/server.cpp) + +# TODO: Extract data_structures to library +set(communication_src_files ${communication_src_files} + ${CMAKE_SOURCE_DIR}/src/data_structures/concurrent/skiplist_gc.cpp) + +# Use this function to add each capnp file to generation. This way each file is +# standalone and we avoid recompiling everything. +# NOTE: communication_src_files and communication_capnp_files are globally updated. +# TODO: This is duplicated from src/CMakeLists.txt and +# src/utils/CMakeLists.txt, find a good way to generalize this on per +# subdirectory basis. +function(add_capnp capnp_src_file) + set(cpp_file ${CMAKE_CURRENT_SOURCE_DIR}/${capnp_src_file}.c++) + set(h_file ${CMAKE_CURRENT_SOURCE_DIR}/${capnp_src_file}.h) + add_custom_command(OUTPUT ${cpp_file} ${h_file} + COMMAND ${CAPNP_EXE} compile -o${CAPNP_CXX_EXE} ${capnp_src_file} -I ${CMAKE_CURRENT_SOURCE_DIR} + DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/${capnp_src_file} capnproto-proj + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + # Update *global* communication_capnp_files + set(communication_capnp_files ${communication_capnp_files} ${cpp_file} ${h_file} PARENT_SCOPE) + # Update *global* communication_src_files + set(communication_src_files ${communication_src_files} ${cpp_file} PARENT_SCOPE) +endfunction(add_capnp) + +add_capnp(rpc/messages.capnp) + +add_custom_target(generate_communication_capnp DEPENDS ${communication_capnp_files}) + +add_library(mg-communication STATIC ${communication_src_files}) +target_link_libraries(mg-communication Threads::Threads mg-utils mg-io fmt glog gflags) +target_link_libraries(mg-communication ${OPENSSL_LIBRARIES}) +target_include_directories(mg-communication SYSTEM PUBLIC ${OPENSSL_INCLUDE_DIR}) +target_link_libraries(mg-communication capnp kj) +add_dependencies(mg-communication generate_communication_capnp) diff --git a/src/communication/bolt/v1/encoder/result_stream.hpp b/src/communication/bolt/v1/encoder/result_stream.hpp index e86b9be3e..4a87aa799 100644 --- a/src/communication/bolt/v1/encoder/result_stream.hpp +++ b/src/communication/bolt/v1/encoder/result_stream.hpp @@ -2,7 +2,6 @@ #include "communication/bolt/v1/encoder/chunked_encoder_buffer.hpp" #include "communication/bolt/v1/encoder/encoder.hpp" -#include "communication/conversion.hpp" namespace communication::bolt { @@ -47,20 +46,10 @@ class ResultStream { * * @param values the values that should be sent */ - void Result(std::vector<DecodedValue> &values) { + void Result(const std::vector<DecodedValue> &values) { encoder_.MessageRecord(values); } - // TODO: Move this to another class - void Result(std::vector<query::TypedValue> &values) { - std::vector<DecodedValue> decoded_values; - decoded_values.reserve(values.size()); - for (const auto &v : values) { - decoded_values.push_back(communication::ToDecodedValue(v)); - } - return Result(decoded_values); - } - /** * Writes a summary. Typically a summary is something like: * { diff --git a/src/communication/bolt/v1/session.hpp b/src/communication/bolt/v1/session.hpp index 300314b5f..34673897a 100644 --- a/src/communication/bolt/v1/session.hpp +++ b/src/communication/bolt/v1/session.hpp @@ -14,21 +14,10 @@ #include "communication/bolt/v1/states/executing.hpp" #include "communication/bolt/v1/states/handshake.hpp" #include "communication/bolt/v1/states/init.hpp" -#include "communication/buffer.hpp" -#include "database/graph_db.hpp" -#include "query/interpreter.hpp" -#include "transactions/transaction.hpp" #include "utils/exceptions.hpp" namespace communication::bolt { -/** Encapsulates Dbms and Interpreter that are passed through the network server - * and worker to the session. */ -struct SessionData { - database::MasterBase &db; - query::Interpreter interpreter{db}; -}; - /** * Bolt Session Exception * @@ -53,17 +42,28 @@ class Session { using ResultStreamT = ResultStream<Encoder<ChunkedEncoderBuffer<TOutputStream>>>; - Session(SessionData &data, TInputStream &input_stream, - TOutputStream &output_stream) - : db_(data.db), - interpreter_(data.interpreter), - input_stream_(input_stream), - output_stream_(output_stream) {} + Session(TInputStream &input_stream, TOutputStream &output_stream) + : input_stream_(input_stream), output_stream_(output_stream) {} - ~Session() { - if (db_accessor_) { - Abort(); - } + virtual ~Session() {} + + /** Return `true` if we are no longer running and accepting queries */ + virtual bool IsShuttingDown() = 0; + + /** + * Put results in the `result_stream` by processing the given `query` with + * `params`. + */ + virtual void PullAll(const std::string &query, + const std::map<std::string, DecodedValue> ¶ms, + ResultStreamT *result_stream) = 0; + + /** Aborts currently running query. */ + virtual void Abort() = 0; + + void PullAll(const std::string &query, + const std::map<std::string, DecodedValue> ¶ms) { + return PullAll(query, params, &result_stream_); } /** @@ -99,11 +99,9 @@ class Session { break; case State::Idle: case State::Result: - case State::WaitForRollback: state_ = StateExecutingRun(*this, state_); break; - case State::ErrorIdle: - case State::ErrorWaitForRollback: + case State::Error: state_ = StateErrorRun(*this, state_); break; default: @@ -122,28 +120,8 @@ class Session { } } - /** - * Commits associated transaction. - */ - void Commit() { - DCHECK(db_accessor_) << "Commit called and there is no transaction"; - db_accessor_->Commit(); - db_accessor_ = nullptr; - } - - /** - * Aborts associated transaction. - */ - void Abort() { - DCHECK(db_accessor_) << "Abort called and there is no transaction"; - db_accessor_->Abort(); - db_accessor_ = nullptr; - } - // TODO: Rethink if there is a way to hide some members. At the momement all // of them are public. - database::MasterBase &db_; - query::Interpreter &interpreter_; TInputStream &input_stream_; TOutputStream &output_stream_; @@ -156,9 +134,6 @@ class Session { bool handshake_done_{false}; State state_{State::Handshake}; - // GraphDbAccessor of active transaction in the session, can be null if - // there is no associated transaction. - std::unique_ptr<database::GraphDbAccessor> db_accessor_; private: void ClientFailureInvalidData() { @@ -176,4 +151,5 @@ class Session { throw SessionException("Something went wrong during session execution!"); } }; + } // namespace communication::bolt diff --git a/src/communication/bolt/v1/state.hpp b/src/communication/bolt/v1/state.hpp index b848ca6cd..b29b506f6 100644 --- a/src/communication/bolt/v1/state.hpp +++ b/src/communication/bolt/v1/state.hpp @@ -31,23 +31,11 @@ enum class State : uint8_t { */ Result, - /** - * There was an acked error in explicitly started transaction, now we are - * waiting for "ROLLBACK" in RUN command. - */ - WaitForRollback, - /** * This state handles errors, if client handles error response correctly next * state is Idle. */ - ErrorIdle, - - /** - * This state handles errors, if client handles error response correctly next - * state is WaitForRollback. - */ - ErrorWaitForRollback, + Error, /** * This is a 'virtual' state (it doesn't have a run function) which tells diff --git a/src/communication/bolt/v1/states/error.hpp b/src/communication/bolt/v1/states/error.hpp index e620991cc..d1c98f8ab 100644 --- a/src/communication/bolt/v1/states/error.hpp +++ b/src/communication/bolt/v1/states/error.hpp @@ -40,20 +40,13 @@ State StateErrorRun(TSession &session, State state) { return State::Close; } if (signature == Signature::Reset) { - if (session.db_accessor_) { - session.Abort(); - } + session.Abort(); return State::Idle; } // We got AckFailure get back to right state. - if (state == State::ErrorIdle) { - return State::Idle; - } else if (state == State::ErrorWaitForRollback) { - return State::WaitForRollback; - } else { - LOG(FATAL) << "Shouldn't happen"; - } + CHECK(state == State::Error) << "Shouldn't happen"; + return State::Idle; } else { uint8_t value = utils::UnderlyingCast(marker); diff --git a/src/communication/bolt/v1/states/executing.hpp b/src/communication/bolt/v1/states/executing.hpp index b687b923a..e7174f973 100644 --- a/src/communication/bolt/v1/states/executing.hpp +++ b/src/communication/bolt/v1/states/executing.hpp @@ -9,15 +9,23 @@ #include "communication/bolt/v1/codes.hpp" #include "communication/bolt/v1/decoder/decoded_value.hpp" #include "communication/bolt/v1/state.hpp" -#include "communication/conversion.hpp" -#include "database/graph_db.hpp" -#include "distributed/pull_rpc_clients.hpp" -#include "query/exceptions.hpp" -#include "query/typed_value.hpp" #include "utils/exceptions.hpp" namespace communication::bolt { +/** + * Used to indicate something is wrong with the client but the transaction is + * kept open for a potential retry. + * + * The most common use case for throwing this error is if something is wrong + * with the query. Perhaps a simple syntax error that can be fixed and query + * retried. + */ +class ClientError : public utils::BasicException { + public: + using utils::BasicException::BasicException; +}; + template <typename TSession> State HandleRun(TSession &session, State state, Marker marker) { const std::map<std::string, DecodedValue> kEmptyFields = { @@ -41,23 +49,6 @@ State HandleRun(TSession &session, State state, Marker marker) { return State::Close; } - if (state == State::WaitForRollback) { - if (query.ValueString() == "ROLLBACK") { - session.Abort(); - // One MessageSuccess for RUN command should be flushed. - session.encoder_.MessageSuccess(kEmptyFields); - // One for PULL_ALL should be chunked. - session.encoder_.MessageSuccess({}, false); - return State::Result; - } - DLOG(WARNING) << "Expected RUN \"ROLLBACK\" not received!"; - // Client could potentially recover if we move to error state, but we don't - // implement rollback of single command in transaction, only rollback of - // whole transaction so we can't continue in this transaction if we receive - // new RUN command. - return State::Close; - } - if (state != State::Idle) { // Client could potentially recover if we move to error state, but there is // no legitimate situation in which well working client would end up in this @@ -70,79 +61,20 @@ State HandleRun(TSession &session, State state, Marker marker) { << "There should be no data to write in this state"; DLOG(INFO) << fmt::format("[Run] '{}'", query.ValueString()); - bool in_explicit_transaction = false; - if (session.db_accessor_) { - // Transaction already exists. - in_explicit_transaction = true; - } else { - // TODO: Possible (but very unlikely) race condition, where we have alive - // session during shutdown, but is_accepting_transactions_ isn't yet false. - // We should probably create transactions under some locking mechanism. - if (!session.db_.is_accepting_transactions()) { - // Db is shutting down and doesn't accept new transactions so we should - // close this session. - return State::Close; - } - // Create new transaction. - session.db_accessor_ = - std::make_unique<database::GraphDbAccessor>(session.db_); - } - // If there was not explicitly started transaction before maybe we are - // starting one now. - if (!in_explicit_transaction && query.ValueString() == "BEGIN") { - // Check if query string is "BEGIN". If it is then we should start - // transaction and wait for in-transaction queries. - // TODO: "BEGIN" is not defined by bolt protocol or opencypher so we should - // test if all drivers really denote transaction start with "BEGIN" string. - // Same goes for "ROLLBACK" and "COMMIT". - // - // One MessageSuccess for RUN command should be flushed. - session.encoder_.MessageSuccess(kEmptyFields); - // One for PULL_ALL should be chunked. - session.encoder_.MessageSuccess({}, false); - return State::Result; + // TODO: Possible (but very unlikely) race condition, where we have alive + // session during shutdown, but IsAcceptingTransactions isn't yet false. + // We should probably create transactions under some locking mechanism. + if (session.IsShuttingDown()) { + // Db is shutting down and doesn't accept new transactions so we should + // close this session. + return State::Close; } try { - // This check is within try block because AdvanceCommand can throw. - if (in_explicit_transaction) { - if (query.ValueString() == "COMMIT") { - session.Commit(); - // One MessageSuccess for RUN command should be flushed. - session.encoder_.MessageSuccess(kEmptyFields); - // One for PULL_ALL should be chunked. - session.encoder_.MessageSuccess({}, false); - return State::Result; - } else if (query.ValueString() == "ROLLBACK") { - session.Abort(); - // One MessageSuccess for RUN command should be flushed. - session.encoder_.MessageSuccess(kEmptyFields); - // One for PULL_ALL should be chunked. - session.encoder_.MessageSuccess({}, false); - return State::Result; - } - session.db_accessor_->AdvanceCommand(); - if (session.db_.type() == database::GraphDb::Type::DISTRIBUTED_MASTER) { - auto tx_id = session.db_accessor_->transaction_id(); - auto futures = - session.db_.pull_clients().NotifyAllTransactionCommandAdvanced( - tx_id); - for (auto &future : futures) future.wait(); - } - } + // PullAll can throw. + session.PullAll(query.ValueString(), params.ValueMap()); - std::map<std::string, query::TypedValue> params_tv; - for (const auto &kv : params.ValueMap()) - params_tv.emplace(kv.first, communication::ToTypedValue(kv.second)); - session - .interpreter_(query.ValueString(), *session.db_accessor_, params_tv, - in_explicit_transaction) - .PullAll(session.result_stream_); - - if (!in_explicit_transaction) { - session.Commit(); - } // The query engine has already stored all query data in the buffer. // We should only send the first chunk now which is the success // message which contains header data. The rest of this data (records @@ -161,7 +93,7 @@ State HandleRun(TSession &session, State state, Marker marker) { } auto code_message = [&e]() -> std::pair<std::string, std::string> { - if (dynamic_cast<const query::QueryException *>(&e)) { + if (dynamic_cast<const ClientError *>(&e)) { // Clients expect 4 strings separated by dots. First being database name // (for example: Neo, Memgraph...), second being either ClientError, // TransientError or DatabaseError (or ClientNotification for warnings). @@ -177,10 +109,6 @@ State HandleRun(TSession &session, State state, Marker marker) { // receives *.TransientError.Transaction.Terminate it will not rerun // query even though TransientError was returned, because of Neo's // semantics of that error. - // - // QueryException was thrown, only changing the query or existing - // database data could make this query successful. Return ClientError to - // discourage retry of this query. return {"Memgraph.ClientError.MemgraphError.MemgraphError", e.what()}; } if (dynamic_cast<const utils::BasicException *>(&e)) { @@ -212,11 +140,7 @@ State HandleRun(TSession &session, State state, Marker marker) { DLOG(WARNING) << "Couldn't send failure message!"; return State::Close; } - if (!in_explicit_transaction) { - session.Abort(); - return State::ErrorIdle; - } - return State::ErrorWaitForRollback; + return State::Error; } } @@ -290,9 +214,7 @@ State HandleReset(Session &session, State, Marker marker) { DLOG(WARNING) << "Couldn't send success message!"; return State::Close; } - if (session.db_accessor_) { - session.Abort(); - } + session.Abort(); return State::Idle; } diff --git a/src/communication/bolt/v1/states/handshake.hpp b/src/communication/bolt/v1/states/handshake.hpp index 441d989dd..6267c56ee 100644 --- a/src/communication/bolt/v1/states/handshake.hpp +++ b/src/communication/bolt/v1/states/handshake.hpp @@ -5,6 +5,7 @@ #include "communication/bolt/v1/codes.hpp" #include "communication/bolt/v1/constants.hpp" #include "communication/bolt/v1/state.hpp" +#include "utils/likely.hpp" namespace communication::bolt { diff --git a/src/communication/result_stream_faker.hpp b/src/communication/result_stream_faker.hpp index 7a0a919b3..f52d8f3bf 100644 --- a/src/communication/result_stream_faker.hpp +++ b/src/communication/result_stream_faker.hpp @@ -37,8 +37,7 @@ class ResultStreamFaker { results_.push_back(values); } - void Summary( - const std::map<std::string, communication::bolt::DecodedValue> &summary) { + void Summary(const std::map<std::string, TResultValue> &summary) { DCHECK(current_state_ != State::Done) << "Can only send a summary once"; summary_ = summary; current_state_ = State::Done; @@ -136,5 +135,5 @@ class ResultStreamFaker { // the data that the record stream can accept std::vector<std::string> header_; std::vector<std::vector<TResultValue>> results_; - std::map<std::string, communication::bolt::DecodedValue> summary_; + std::map<std::string, TResultValue> summary_; }; diff --git a/src/database/state_delta.cpp b/src/database/state_delta.cpp index 04972e688..817017583 100644 --- a/src/database/state_delta.cpp +++ b/src/database/state_delta.cpp @@ -3,8 +3,8 @@ #include <string> #include "communication/bolt/v1/decoder/decoded_value.hpp" -#include "communication/conversion.hpp" #include "database/graph_db_accessor.hpp" +#include "glue/conversion.hpp" namespace database { @@ -208,13 +208,13 @@ void StateDelta::Encode( encoder.WriteInt(vertex_id); encoder.WriteInt(property.Id()); encoder.WriteString(property_name); - encoder.WriteDecodedValue(communication::ToDecodedValue(value)); + encoder.WriteDecodedValue(glue::ToDecodedValue(value)); break; case Type::SET_PROPERTY_EDGE: encoder.WriteInt(edge_id); encoder.WriteInt(property.Id()); encoder.WriteString(property_name); - encoder.WriteDecodedValue(communication::ToDecodedValue(value)); + encoder.WriteDecodedValue(glue::ToDecodedValue(value)); break; case Type::ADD_LABEL: case Type::REMOVE_LABEL: @@ -304,14 +304,14 @@ std::experimental::optional<StateDelta> StateDelta::Decode( DECODE_MEMBER_CAST(property, ValueInt, storage::Property) DECODE_MEMBER(property_name, ValueString) if (!decoder.ReadValue(&dv)) return nullopt; - r_val.value = communication::ToPropertyValue(dv); + r_val.value = glue::ToPropertyValue(dv); break; case Type::SET_PROPERTY_EDGE: DECODE_MEMBER(edge_id, ValueInt) DECODE_MEMBER_CAST(property, ValueInt, storage::Property) DECODE_MEMBER(property_name, ValueString) if (!decoder.ReadValue(&dv)) return nullopt; - r_val.value = communication::ToPropertyValue(dv); + r_val.value = glue::ToPropertyValue(dv); break; case Type::ADD_LABEL: case Type::REMOVE_LABEL: diff --git a/src/durability/recovery.cpp b/src/durability/recovery.cpp index ddab0b15e..a77de5228 100644 --- a/src/durability/recovery.cpp +++ b/src/durability/recovery.cpp @@ -4,7 +4,6 @@ #include <limits> #include <unordered_map> -#include "communication/conversion.hpp" #include "database/graph_db_accessor.hpp" #include "database/indexes/label_property_index.hpp" #include "durability/hashed_file_reader.hpp" @@ -13,6 +12,7 @@ #include "durability/snapshot_decoder.hpp" #include "durability/version.hpp" #include "durability/wal.hpp" +#include "glue/conversion.hpp" #include "query/typed_value.hpp" #include "storage/address_types.hpp" #include "transactions/type.hpp" @@ -127,15 +127,13 @@ bool RecoverSnapshot(const fs::path &snapshot_file, database::GraphDb &db, auto vertex = decoder.ReadSnapshotVertex(); RETURN_IF_NOT(vertex); - auto vertex_accessor = - dba.InsertVertex(vertex->gid, vertex->cypher_id); + auto vertex_accessor = dba.InsertVertex(vertex->gid, vertex->cypher_id); for (const auto &label : vertex->labels) { vertex_accessor.add_label(dba.Label(label)); } for (const auto &property_pair : vertex->properties) { - vertex_accessor.PropsSet( - dba.Property(property_pair.first), - communication::ToTypedValue(property_pair.second)); + vertex_accessor.PropsSet(dba.Property(property_pair.first), + glue::ToTypedValue(property_pair.second)); } auto vertex_record = vertex_accessor.GetNew(); for (const auto &edge : vertex->in) { @@ -205,7 +203,7 @@ bool RecoverSnapshot(const fs::path &snapshot_file, database::GraphDb &db, for (const auto &property_pair : edge.properties) edge_accessor.PropsSet(dba.Property(property_pair.first), - communication::ToTypedValue(property_pair.second)); + glue::ToTypedValue(property_pair.second)); } // Vertex and edge counts are included in the hash. Re-read them to update the diff --git a/src/durability/snapshooter.cpp b/src/durability/snapshooter.cpp index 8258ec8d8..af7624498 100644 --- a/src/durability/snapshooter.cpp +++ b/src/durability/snapshooter.cpp @@ -67,7 +67,7 @@ bool Encode(const fs::path &snapshot_file, database::GraphDb &db, vertex_num++; } for (const auto &edge : dba.Edges(false)) { - encoder.WriteEdge(communication::ToDecodedEdge(edge)); + encoder.WriteEdge(glue::ToDecodedEdge(edge)); encoder.WriteInt(edge.cypher_id()); edge_num++; } diff --git a/src/durability/snapshot_encoder.hpp b/src/durability/snapshot_encoder.hpp index 8edd284d3..9180c747f 100644 --- a/src/durability/snapshot_encoder.hpp +++ b/src/durability/snapshot_encoder.hpp @@ -1,8 +1,8 @@ #pragma once #include "communication/bolt/v1/encoder/base_encoder.hpp" -#include "communication/conversion.hpp" #include "database/graph_db_accessor.hpp" +#include "glue/conversion.hpp" #include "utils/cast.hpp" namespace durability { @@ -14,7 +14,7 @@ class SnapshotEncoder : public communication::bolt::BaseEncoder<Buffer> { : communication::bolt::BaseEncoder<Buffer>(buffer) {} void WriteSnapshotVertex(const VertexAccessor &vertex) { communication::bolt::BaseEncoder<Buffer>::WriteVertex( - communication::ToDecodedVertex(vertex)); + glue::ToDecodedVertex(vertex)); // Write cypher_id this->WriteInt(vertex.cypher_id()); diff --git a/src/communication/conversion.cpp b/src/glue/conversion.cpp similarity index 98% rename from src/communication/conversion.cpp rename to src/glue/conversion.cpp index 26211c2ba..1411f0574 100644 --- a/src/communication/conversion.cpp +++ b/src/glue/conversion.cpp @@ -1,4 +1,4 @@ -#include "communication/conversion.hpp" +#include "glue/conversion.hpp" #include <map> #include <string> @@ -8,7 +8,7 @@ using communication::bolt::DecodedValue; -namespace communication { +namespace glue { query::TypedValue ToTypedValue(const DecodedValue &value) { switch (value.type()) { @@ -189,4 +189,4 @@ DecodedValue ToDecodedValue(const PropertyValue &value) { } } -} // namespace communication +} // namespace glue diff --git a/src/communication/conversion.hpp b/src/glue/conversion.hpp similarity index 93% rename from src/communication/conversion.hpp rename to src/glue/conversion.hpp index 2c16facb1..9f2ff8ab1 100644 --- a/src/communication/conversion.hpp +++ b/src/glue/conversion.hpp @@ -5,7 +5,7 @@ #include "query/typed_value.hpp" #include "storage/property_value.hpp" -namespace communication { +namespace glue { communication::bolt::DecodedVertex ToDecodedVertex( const VertexAccessor &vertex); @@ -23,4 +23,4 @@ communication::bolt::DecodedValue ToDecodedValue(const PropertyValue &value); PropertyValue ToPropertyValue(const communication::bolt::DecodedValue &value); -} // namespace communication +} // namespace glue diff --git a/src/memgraph_bolt.cpp b/src/memgraph_bolt.cpp index 56b231c6a..af850020a 100644 --- a/src/memgraph_bolt.cpp +++ b/src/memgraph_bolt.cpp @@ -14,6 +14,10 @@ #include "communication/bolt/v1/session.hpp" #include "config.hpp" #include "database/graph_db.hpp" +#include "distributed/pull_rpc_clients.hpp" +#include "glue/conversion.hpp" +#include "query/exceptions.hpp" +#include "query/interpreter.hpp" #include "stats/stats.hpp" #include "telemetry/telemetry.hpp" #include "utils/flag_validation.hpp" @@ -24,12 +28,6 @@ // Common stuff for enterprise and community editions -using communication::bolt::SessionData; -using SessionT = communication::bolt::Session<communication::InputStream, - communication::OutputStream>; -using ServerT = communication::Server<SessionT, SessionData>; -using communication::ServerContext; - // General purpose flags. DEFINE_string(interface, "0.0.0.0", "Communication interface on which to listen."); @@ -59,6 +57,173 @@ DEFINE_bool(telemetry_enabled, false, "to allow for easier improvement of the product."); DECLARE_string(durability_directory); +/** Encapsulates Dbms and Interpreter that are passed through the network server + * and worker to the session. */ +struct SessionData { + database::MasterBase &db; + query::Interpreter interpreter{db}; +}; + +class BoltSession final + : public communication::bolt::Session<communication::InputStream, + communication::OutputStream> { + public: + BoltSession(SessionData &data, communication::InputStream &input_stream, + communication::OutputStream &output_stream) + : communication::bolt::Session<communication::InputStream, + communication::OutputStream>( + input_stream, output_stream), + db_(data.db), + interpreter_(data.interpreter) {} + + ~BoltSession() { + if (db_accessor_) { + Abort(); + } + } + + bool IsShuttingDown() override { return !db_.is_accepting_transactions(); } + + using communication::bolt::Session< + communication::InputStream, communication::OutputStream>::ResultStreamT; + + void PullAll( + const std::string &query, + const std::map<std::string, communication::bolt::DecodedValue> ¶ms, + ResultStreamT *result_stream) override { + bool in_explicit_transaction = !!db_accessor_; + if (!db_accessor_) + db_accessor_ = std::make_unique<database::GraphDbAccessor>(db_); + // TODO: Queries below should probably move to interpreter, but there is + // only one interpreter in GraphDb. We probably need some kind of + // per-session access for the interpreter. + // TODO: Also, write tests for these queries + if (expect_rollback_ && query != "ROLLBACK") { + // Client could potentially recover if we move to error state, but we + // don't implement rollback of single command in transaction, only + // rollback of whole transaction so we can't continue in this transaction + // if we receive new query. + throw communication::bolt::ClientError( + "Expected ROLLBACK, because previous query contained an error"); + } + if (query == "ROLLBACK") { + if (!in_explicit_transaction) + throw communication::bolt::ClientError( + "ROLLBACK can only be used after BEGIN"); + Abort(); + result_stream->Header({}); + result_stream->Summary({}); + return; + } else if (query == "BEGIN") { + if (in_explicit_transaction) + throw communication::bolt::ClientError("BEGIN already called"); + // We accept BEGIN, so send the empty results. + result_stream->Header({}); + result_stream->Summary({}); + return; + } else if (query == "COMMIT") { + if (!in_explicit_transaction) + throw communication::bolt::ClientError( + "COMMIT can only be used after BEGIN"); + Commit(); + result_stream->Header({}); + result_stream->Summary({}); + return; + } + // Any other query in BEGIN block advances the command. + if (in_explicit_transaction) AdvanceCommand(); + // Handle regular Cypher queries below + std::map<std::string, query::TypedValue> params_tv; + for (const auto &kv : params) + params_tv.emplace(kv.first, glue::ToTypedValue(kv.second)); + auto abort_tx = [this, in_explicit_transaction]() { + if (in_explicit_transaction) + expect_rollback_ = true; + else + this->Abort(); + }; + try { + TypedValueResultStream stream(result_stream); + interpreter_(query, *db_accessor_, params_tv, in_explicit_transaction) + .PullAll(stream); + if (!in_explicit_transaction) Commit(); + } catch (const query::QueryException &e) { + abort_tx(); + // Wrap QueryException into ClientError, because we want to allow the + // client to fix their query. + throw communication::bolt::ClientError(e.what()); + } catch (const utils::BasicException &) { + abort_tx(); + throw; + } + } + + void Abort() override { + if (!db_accessor_) return; + db_accessor_->Abort(); + db_accessor_ = nullptr; + } + + private: + // Wrapper around ResultStreamT which converts TypedValue to DecodedValue + // before forwarding the calls to original ResultStreamT. + class TypedValueResultStream { + public: + TypedValueResultStream(ResultStreamT *result_stream) + : result_stream_(result_stream) {} + + void Header(const std::vector<std::string> &fields) { + return result_stream_->Header(fields); + } + + void Result(const std::vector<query::TypedValue> &values) { + std::vector<communication::bolt::DecodedValue> decoded_values; + decoded_values.reserve(values.size()); + for (const auto &v : values) { + decoded_values.push_back(glue::ToDecodedValue(v)); + } + return result_stream_->Result(decoded_values); + } + + void Summary(const std::map<std::string, query::TypedValue> &summary) { + std::map<std::string, communication::bolt::DecodedValue> decoded_summary; + for (const auto &kv : summary) { + decoded_summary.emplace(kv.first, glue::ToDecodedValue(kv.second)); + } + return result_stream_->Summary(decoded_summary); + } + + private: + ResultStreamT *result_stream_; + }; + + database::MasterBase &db_; + query::Interpreter &interpreter_; + // GraphDbAccessor of active transaction in the session, can be null if + // there is no associated transaction. + std::unique_ptr<database::GraphDbAccessor> db_accessor_; + bool expect_rollback_{false}; + + void Commit() { + DCHECK(db_accessor_) << "Commit called and there is no transaction"; + db_accessor_->Commit(); + db_accessor_ = nullptr; + } + + void AdvanceCommand() { + db_accessor_->AdvanceCommand(); + if (db_.type() == database::GraphDb::Type::DISTRIBUTED_MASTER) { + auto tx_id = db_accessor_->transaction_id(); + auto futures = + db_.pull_clients().NotifyAllTransactionCommandAdvanced(tx_id); + for (auto &future : futures) future.wait(); + } + } +}; + +using ServerT = communication::Server<BoltSession, SessionData>; +using communication::ServerContext; + // Needed to correctly handle memgraph destruction from a signal handler. // Without having some sort of a flag, it is possible that a signal is handled // when we are exiting main, inside destructors of database::GraphDb and diff --git a/src/query/interpreter.cpp b/src/query/interpreter.cpp index 74a4be1fc..5c5df863b 100644 --- a/src/query/interpreter.cpp +++ b/src/query/interpreter.cpp @@ -98,7 +98,7 @@ Interpreter::Results Interpreter::operator()( ctx.symbol_table_ = plan->symbol_table(); - std::map<std::string, communication::bolt::DecodedValue> summary; + std::map<std::string, TypedValue> summary; summary["parsing_time"] = frontend_time.count(); summary["planning_time"] = planning_time.count(); summary["cost_estimate"] = plan->cost(); diff --git a/src/query/interpreter.hpp b/src/query/interpreter.hpp index db90fe1af..fae3b95bb 100644 --- a/src/query/interpreter.hpp +++ b/src/query/interpreter.hpp @@ -2,8 +2,6 @@ #include <gflags/gflags.h> -#include "communication/bolt/v1/encoder/base_encoder.hpp" -#include "communication/conversion.hpp" #include "data_structures/concurrent/concurrent_map.hpp" #include "database/graph_db.hpp" #include "database/graph_db_accessor.hpp" @@ -68,8 +66,7 @@ class Interpreter { Results(Context ctx, std::shared_ptr<CachedPlan> plan, std::unique_ptr<query::plan::Cursor> cursor, std::vector<Symbol> output_symbols, std::vector<std::string> header, - std::map<std::string, communication::bolt::DecodedValue> summary, - PlanCacheT &plan_cache) + std::map<std::string, TypedValue> summary, PlanCacheT &plan_cache) : ctx_(std::move(ctx)), plan_(plan), cursor_(std::move(cursor)), @@ -144,7 +141,7 @@ class Interpreter { bool header_written_{false}; std::vector<std::string> header_; - std::map<std::string, communication::bolt::DecodedValue> summary_; + std::map<std::string, TypedValue> summary_; utils::Timer execution_timer_; // Gets invalidated after if an index has been built. diff --git a/src/storage/pod_buffer.hpp b/src/storage/pod_buffer.hpp index 45c17203e..0326331e4 100644 --- a/src/storage/pod_buffer.hpp +++ b/src/storage/pod_buffer.hpp @@ -1,6 +1,9 @@ #pragma once -#include "communication/bolt/v1/encoder/base_encoder.hpp" +#include <cstdint> +#include <cstring> +#include <string> +#include <vector> namespace storage { diff --git a/src/storage/property_value_store.cpp b/src/storage/property_value_store.cpp index 4f6d5d70e..1c5720db1 100644 --- a/src/storage/property_value_store.cpp +++ b/src/storage/property_value_store.cpp @@ -4,7 +4,8 @@ #include "glog/logging.h" #include "communication/bolt/v1/decoder/decoder.hpp" -#include "communication/conversion.hpp" +#include "communication/bolt/v1/encoder/base_encoder.hpp" +#include "glue/conversion.hpp" #include "storage/pod_buffer.hpp" #include "storage/property_value_store.hpp" @@ -213,7 +214,7 @@ PropertyValueStore::iterator PropertyValueStore::end() const { std::string PropertyValueStore::SerializeProp(const PropertyValue &prop) const { storage::PODBuffer pod_buffer; BaseEncoder<storage::PODBuffer> encoder{pod_buffer}; - encoder.WriteDecodedValue(communication::ToDecodedValue(prop)); + encoder.WriteDecodedValue(glue::ToDecodedValue(prop)); return std::string(reinterpret_cast<char *>(pod_buffer.buffer.data()), pod_buffer.buffer.size()); } @@ -228,7 +229,7 @@ PropertyValue PropertyValueStore::DeserializeProp( DLOG(WARNING) << "Unable to read property value"; return PropertyValue::Null; } - return communication::ToPropertyValue(dv); + return glue::ToPropertyValue(dv); } storage::KVStore PropertyValueStore::ConstructDiskStorage() const { diff --git a/tests/manual/CMakeLists.txt b/tests/manual/CMakeLists.txt index aab46eb77..b53e414b3 100644 --- a/tests/manual/CMakeLists.txt +++ b/tests/manual/CMakeLists.txt @@ -28,7 +28,7 @@ add_manual_test(binomial.cpp) target_link_libraries(${test_prefix}binomial mg-utils) add_manual_test(bolt_client.cpp) -target_link_libraries(${test_prefix}bolt_client memgraph_lib kvstore_dummy_lib) +target_link_libraries(${test_prefix}bolt_client mg-communication) add_manual_test(card_fraud_generate_snapshot.cpp) target_link_libraries(${test_prefix}card_fraud_generate_snapshot memgraph_lib kvstore_dummy_lib) @@ -72,10 +72,10 @@ add_manual_test(stripped_timing.cpp) target_link_libraries(${test_prefix}stripped_timing memgraph_lib kvstore_dummy_lib) add_manual_test(ssl_client.cpp) -target_link_libraries(${test_prefix}ssl_client memgraph_lib kvstore_dummy_lib) +target_link_libraries(${test_prefix}ssl_client mg-communication) add_manual_test(ssl_server.cpp) -target_link_libraries(${test_prefix}ssl_server memgraph_lib kvstore_dummy_lib) +target_link_libraries(${test_prefix}ssl_server mg-communication) add_manual_test(xorshift.cpp) target_link_libraries(${test_prefix}xorshift mg-utils) diff --git a/tests/manual/distributed_common.hpp b/tests/manual/distributed_common.hpp index 0ac0d1729..676d90d90 100644 --- a/tests/manual/distributed_common.hpp +++ b/tests/manual/distributed_common.hpp @@ -3,9 +3,9 @@ #include <chrono> #include <vector> -#include "communication/conversion.hpp" #include "communication/result_stream_faker.hpp" #include "database/graph_db_accessor.hpp" +#include "glue/conversion.hpp" #include "query/interpreter.hpp" #include "query/typed_value.hpp" diff --git a/tests/manual/snapshot_generation/snapshot_writer.hpp b/tests/manual/snapshot_generation/snapshot_writer.hpp index df8881ecd..26e190ef0 100644 --- a/tests/manual/snapshot_generation/snapshot_writer.hpp +++ b/tests/manual/snapshot_generation/snapshot_writer.hpp @@ -4,10 +4,10 @@ #include <vector> #include "communication/bolt/v1/encoder/base_encoder.hpp" -#include "communication/conversion.hpp" #include "durability/hashed_file_writer.hpp" #include "durability/paths.hpp" #include "durability/version.hpp" +#include "glue/conversion.hpp" #include "query/typed_value.hpp" #include "utils/file.hpp" @@ -71,7 +71,7 @@ class SnapshotWriter { WriteList(node.labels); std::map<std::string, communication::bolt::DecodedValue> props; for (const auto &prop : node.props) { - props[prop.first] = communication::ToDecodedValue(prop.second); + props[prop.first] = glue::ToDecodedValue(prop.second); } encoder_.WriteMap(props); @@ -102,7 +102,7 @@ class SnapshotWriter { encoder_.WriteString(edge.type); std::map<std::string, communication::bolt::DecodedValue> props; for (const auto &prop : edge.props) { - props[prop.first] = communication::ToDecodedValue(prop.second); + props[prop.first] = glue::ToDecodedValue(prop.second); } encoder_.WriteMap(props); diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index 466de9102..f2fef0b19 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -19,27 +19,9 @@ function(add_unit_test test_cpp) add_dependencies(memgraph__unit ${target_name}) endfunction(add_unit_test) -add_unit_test(bolt_chunked_decoder_buffer.cpp) -target_link_libraries(${test_prefix}bolt_chunked_decoder_buffer memgraph_lib kvstore_dummy_lib) - -add_unit_test(bolt_chunked_encoder_buffer.cpp) -target_link_libraries(${test_prefix}bolt_chunked_encoder_buffer memgraph_lib kvstore_dummy_lib) - -add_unit_test(bolt_decoder.cpp) -target_link_libraries(${test_prefix}bolt_decoder memgraph_lib kvstore_dummy_lib) - add_unit_test(bolt_encoder.cpp) target_link_libraries(${test_prefix}bolt_encoder memgraph_lib kvstore_dummy_lib) -add_unit_test(bolt_result_stream.cpp) -target_link_libraries(${test_prefix}bolt_result_stream memgraph_lib kvstore_dummy_lib) - -add_unit_test(bolt_session.cpp) -target_link_libraries(${test_prefix}bolt_session memgraph_lib kvstore_dummy_lib) - -add_unit_test(communication_buffer.cpp) -target_link_libraries(${test_prefix}communication_buffer memgraph_lib kvstore_dummy_lib) - add_unit_test(concurrent_id_mapper_distributed.cpp) target_link_libraries(${test_prefix}concurrent_id_mapper_distributed memgraph_lib kvstore_dummy_lib) @@ -157,9 +139,6 @@ target_link_libraries(${test_prefix}mvcc_one_transaction memgraph_lib kvstore_du add_unit_test(mvcc_parallel_update.cpp) target_link_libraries(${test_prefix}mvcc_parallel_update memgraph_lib kvstore_dummy_lib) -add_unit_test(network_timeouts.cpp) -target_link_libraries(${test_prefix}network_timeouts memgraph_lib kvstore_dummy_lib) - add_unit_test(pod_buffer.cpp) target_link_libraries(${test_prefix}pod_buffer memgraph_lib kvstore_dummy_lib) @@ -208,9 +187,6 @@ target_link_libraries(${test_prefix}raft_storage memgraph_lib kvstore_dummy_lib) add_unit_test(record_edge_vertex_accessor.cpp) target_link_libraries(${test_prefix}record_edge_vertex_accessor memgraph_lib kvstore_dummy_lib) -add_unit_test(rpc.cpp) -target_link_libraries(${test_prefix}rpc memgraph_lib kvstore_dummy_lib) - add_unit_test(rpc_worker_clients.cpp) target_link_libraries(${test_prefix}rpc_worker_clients memgraph_lib kvstore_dummy_lib) @@ -253,6 +229,32 @@ target_link_libraries(${test_prefix}transaction_engine_single_node memgraph_lib add_unit_test(typed_value.cpp) target_link_libraries(${test_prefix}typed_value memgraph_lib kvstore_dummy_lib) +# Test mg-communication + +add_unit_test(bolt_chunked_decoder_buffer.cpp) +target_link_libraries(${test_prefix}bolt_chunked_decoder_buffer mg-communication) + +add_unit_test(bolt_chunked_encoder_buffer.cpp) +target_link_libraries(${test_prefix}bolt_chunked_encoder_buffer mg-communication) + +add_unit_test(bolt_decoder.cpp) +target_link_libraries(${test_prefix}bolt_decoder mg-communication) + +add_unit_test(bolt_result_stream.cpp) +target_link_libraries(${test_prefix}bolt_result_stream mg-communication) + +add_unit_test(bolt_session.cpp) +target_link_libraries(${test_prefix}bolt_session mg-communication) + +add_unit_test(communication_buffer.cpp) +target_link_libraries(${test_prefix}communication_buffer mg-communication) + +add_unit_test(network_timeouts.cpp) +target_link_libraries(${test_prefix}network_timeouts mg-communication) + +add_unit_test(rpc.cpp) +target_link_libraries(${test_prefix}rpc mg-communication) + # Test data structures add_unit_test(ring_buffer.cpp) diff --git a/tests/unit/bolt_common.hpp b/tests/unit/bolt_common.hpp index 199c5b1f0..47dd75b3a 100644 --- a/tests/unit/bolt_common.hpp +++ b/tests/unit/bolt_common.hpp @@ -6,8 +6,7 @@ #include <vector> #include <glog/logging.h> - -#include "gtest/gtest.h" +#include <gtest/gtest.h> /** * TODO (mferencevic): document diff --git a/tests/unit/bolt_decoder.cpp b/tests/unit/bolt_decoder.cpp index ea3b3d756..64d062673 100644 --- a/tests/unit/bolt_decoder.cpp +++ b/tests/unit/bolt_decoder.cpp @@ -4,7 +4,6 @@ #include "bolt_testdata.hpp" #include "communication/bolt/v1/decoder/decoder.hpp" -#include "query/typed_value.hpp" using communication::bolt::DecodedValue; diff --git a/tests/unit/bolt_encoder.cpp b/tests/unit/bolt_encoder.cpp index 598ec6a04..ee687cf2e 100644 --- a/tests/unit/bolt_encoder.cpp +++ b/tests/unit/bolt_encoder.cpp @@ -2,9 +2,9 @@ #include "bolt_testdata.hpp" #include "communication/bolt/v1/encoder/encoder.hpp" -#include "communication/conversion.hpp" #include "database/graph_db.hpp" #include "database/graph_db_accessor.hpp" +#include "glue/conversion.hpp" using communication::bolt::DecodedValue; @@ -189,9 +189,9 @@ TEST(BoltEncoder, VertexAndEdge) { // check everything std::vector<DecodedValue> vals; - vals.push_back(communication::ToDecodedValue(va1)); - vals.push_back(communication::ToDecodedValue(va2)); - vals.push_back(communication::ToDecodedValue(ea)); + vals.push_back(glue::ToDecodedValue(va1)); + vals.push_back(glue::ToDecodedValue(va2)); + vals.push_back(glue::ToDecodedValue(ea)); bolt_encoder.MessageRecord(vals); // The vertexedge_encoded testdata has hardcoded zeros for IDs, diff --git a/tests/unit/bolt_session.cpp b/tests/unit/bolt_session.cpp index 5cbb8dea5..f8eeb2f9c 100644 --- a/tests/unit/bolt_session.cpp +++ b/tests/unit/bolt_session.cpp @@ -2,27 +2,60 @@ #include <glog/logging.h> #include "bolt_common.hpp" -#include "communication/bolt/v1/encoder/result_stream.hpp" #include "communication/bolt/v1/session.hpp" -#include "database/graph_db.hpp" + +using communication::bolt::ClientError; +using communication::bolt::DecodedValue; +using communication::bolt::Session; +using communication::bolt::SessionException; +using communication::bolt::State; + +static const char *kInvalidQuery = "invalid query"; +static const char *kQueryReturn42 = "RETURN 42"; +static const char *kQueryEmpty = "no results"; + +class TestSessionData {}; + +class TestSession : public Session<TestInputStream, TestOutputStream> { + public: + using Session<TestInputStream, TestOutputStream>::ResultStreamT; + + TestSession(TestSessionData &data, TestInputStream &input_stream, + TestOutputStream &output_stream) + : Session<TestInputStream, TestOutputStream>(input_stream, + output_stream) {} + + bool IsShuttingDown() override { return false; } + + void PullAll(const std::string &query, + const std::map<std::string, DecodedValue> ¶ms, + ResultStreamT *result_stream) override { + if (query == kQueryReturn42) { + result_stream->Header({"result_name"}); + result_stream->Result(std::vector<DecodedValue>{42}); + result_stream->Summary({}); + } else if (query == kQueryEmpty) { + result_stream->Header({"result_name"}); + result_stream->Summary({}); + } else { + throw ClientError("client sent invalid query"); + } + } + + void Abort() override {} +}; + +using ResultStreamT = TestSession::ResultStreamT; // TODO: This could be done in fixture. // Shortcuts for writing variable initializations in tests -#define INIT_VARS \ - TestInputStream input_stream; \ - TestOutputStream output_stream; \ - database::SingleNode db; \ - SessionData session_data{db}; \ - SessionT session(session_data, input_stream, output_stream); \ +#define INIT_VARS \ + TestInputStream input_stream; \ + TestOutputStream output_stream; \ + TestSessionData session_data; \ + TestSession session(session_data, input_stream, output_stream); \ std::vector<uint8_t> &output = output_stream.output; -using communication::bolt::SessionData; -using communication::bolt::SessionException; -using communication::bolt::State; -using SessionT = - communication::bolt::Session<TestInputStream, TestOutputStream>; -using ResultStreamT = SessionT::ResultStreamT; - // Sample testdata that has correct inputs and outputs. const uint8_t handshake_req[] = {0x60, 0x60, 0xb0, 0x17, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, @@ -85,7 +118,7 @@ void CheckIgnoreMessage(std::vector<uint8_t> &output) { } // Execute and check a correct handshake -void ExecuteHandshake(TestInputStream &input_stream, SessionT &session, +void ExecuteHandshake(TestInputStream &input_stream, TestSession &session, std::vector<uint8_t> &output) { input_stream.Write(handshake_req, 20); session.Execute(); @@ -95,7 +128,7 @@ void ExecuteHandshake(TestInputStream &input_stream, SessionT &session, } // Write bolt chunk and execute command -void ExecuteCommand(TestInputStream &input_stream, SessionT &session, +void ExecuteCommand(TestInputStream &input_stream, TestSession &session, const uint8_t *data, size_t len, bool chunk = true) { if (chunk) WriteChunkHeader(input_stream, len); input_stream.Write(data, len); @@ -104,7 +137,7 @@ void ExecuteCommand(TestInputStream &input_stream, SessionT &session, } // Execute and check a correct init -void ExecuteInit(TestInputStream &input_stream, SessionT &session, +void ExecuteInit(TestInputStream &input_stream, TestSession &session, std::vector<uint8_t> &output) { ExecuteCommand(input_stream, session, init_req, sizeof(init_req)); ASSERT_EQ(session.state_, State::Idle); @@ -277,7 +310,7 @@ TEST(BoltSession, ExecuteRunBasicException) { ExecuteInit(input_stream, session, output); output_stream.SetWriteSuccess(i == 0); - WriteRunRequest(input_stream, "MATCH (omnom"); + WriteRunRequest(input_stream, kInvalidQuery); if (i == 0) { session.Execute(); } else { @@ -285,7 +318,7 @@ TEST(BoltSession, ExecuteRunBasicException) { } if (i == 0) { - ASSERT_EQ(session.state_, State::ErrorIdle); + ASSERT_EQ(session.state_, State::Error); CheckFailureMessage(output); } else { ASSERT_EQ(session.state_, State::Close); @@ -300,7 +333,7 @@ TEST(BoltSession, ExecuteRunWithoutPullAll) { ExecuteHandshake(input_stream, session, output); ExecuteInit(input_stream, session, output); - WriteRunRequest(input_stream, "RETURN 2"); + WriteRunRequest(input_stream, kQueryReturn42); session.Execute(); ASSERT_EQ(session.state_, State::Result); @@ -361,7 +394,7 @@ TEST(BoltSession, ExecutePullAllDiscardAllReset) { ExecuteHandshake(input_stream, session, output); ExecuteInit(input_stream, session, output); - WriteRunRequest(input_stream, "CREATE (n) RETURN n"); + WriteRunRequest(input_stream, kQueryReturn42); session.Execute(); if (j == 1) output.clear(); @@ -407,7 +440,7 @@ TEST(BoltSession, ErrorIgnoreMessage) { ExecuteHandshake(input_stream, session, output); ExecuteInit(input_stream, session, output); - WriteRunRequest(input_stream, "MATCH (omnom"); + WriteRunRequest(input_stream, kInvalidQuery); session.Execute(); output.clear(); @@ -425,7 +458,7 @@ TEST(BoltSession, ErrorIgnoreMessage) { ASSERT_EQ(session.decoder_buffer_.Size(), 0); if (i == 0) { - ASSERT_EQ(session.state_, State::ErrorIdle); + ASSERT_EQ(session.state_, State::Error); CheckOutput(output, ignored_resp, sizeof(ignored_resp)); } else { ASSERT_EQ(session.state_, State::Close); @@ -441,7 +474,7 @@ TEST(BoltSession, ErrorRunAfterRun) { ExecuteHandshake(input_stream, session, output); ExecuteInit(input_stream, session, output); - WriteRunRequest(input_stream, "MATCH (n) RETURN n"); + WriteRunRequest(input_stream, kQueryReturn42); session.Execute(); output.clear(); @@ -452,7 +485,7 @@ TEST(BoltSession, ErrorRunAfterRun) { ASSERT_EQ(session.state_, State::Result); // New run request. - WriteRunRequest(input_stream, "MATCH (n) RETURN n"); + WriteRunRequest(input_stream, kQueryReturn42); ASSERT_THROW(session.Execute(), SessionException); ASSERT_EQ(session.state_, State::Close); @@ -464,7 +497,7 @@ TEST(BoltSession, ErrorCantCleanup) { ExecuteHandshake(input_stream, session, output); ExecuteInit(input_stream, session, output); - WriteRunRequest(input_stream, "MATCH (omnom"); + WriteRunRequest(input_stream, kInvalidQuery); session.Execute(); output.clear(); @@ -484,7 +517,7 @@ TEST(BoltSession, ErrorWrongMarker) { ExecuteHandshake(input_stream, session, output); ExecuteInit(input_stream, session, output); - WriteRunRequest(input_stream, "MATCH (omnom"); + WriteRunRequest(input_stream, kInvalidQuery); session.Execute(); output.clear(); @@ -510,7 +543,7 @@ TEST(BoltSession, ErrorOK) { ExecuteHandshake(input_stream, session, output); ExecuteInit(input_stream, session, output); - WriteRunRequest(input_stream, "MATCH (omnom"); + WriteRunRequest(input_stream, kInvalidQuery); session.Execute(); output.clear(); @@ -543,7 +576,7 @@ TEST(BoltSession, ErrorMissingData) { ExecuteHandshake(input_stream, session, output); ExecuteInit(input_stream, session, output); - WriteRunRequest(input_stream, "MATCH (omnom"); + WriteRunRequest(input_stream, kInvalidQuery); session.Execute(); output.clear(); @@ -563,7 +596,7 @@ TEST(BoltSession, MultipleChunksInOneExecute) { ExecuteHandshake(input_stream, session, output); ExecuteInit(input_stream, session, output); - WriteRunRequest(input_stream, "CREATE (n) RETURN n"); + WriteRunRequest(input_stream, kQueryReturn42); ExecuteCommand(input_stream, session, pullall_req, sizeof(pullall_req)); ASSERT_EQ(session.state_, State::Idle); @@ -607,117 +640,6 @@ TEST(BoltSession, PartialChunk) { PrintOutput(output); } -TEST(BoltSession, ExplicitTransactionValidQueries) { - // It is not really easy to check if we commited or aborted transaction except - // by faking GraphDb/TxEngine... - std::vector<std::string> transaction_ends = {"COMMIT", "ROLLBACK"}; - - for (const auto &transaction_end : transaction_ends) { - INIT_VARS; - - ExecuteHandshake(input_stream, session, output); - ExecuteInit(input_stream, session, output); - - WriteRunRequest(input_stream, "BEGIN"); - session.Execute(); - ASSERT_EQ(session.state_, State::Result); - ASSERT_TRUE(session.db_accessor_); - CheckSuccessMessage(output); - - ExecuteCommand(input_stream, session, pullall_req, sizeof(pullall_req)); - session.Execute(); - ASSERT_EQ(session.state_, State::Idle); - ASSERT_TRUE(session.db_accessor_); - CheckSuccessMessage(output); - - WriteRunRequest(input_stream, "MATCH (n) RETURN n"); - session.Execute(); - ASSERT_EQ(session.state_, State::Result); - ASSERT_TRUE(session.db_accessor_); - CheckSuccessMessage(output); - - ExecuteCommand(input_stream, session, pullall_req, sizeof(pullall_req)); - session.Execute(); - ASSERT_EQ(session.state_, State::Idle); - ASSERT_TRUE(session.db_accessor_); - CheckSuccessMessage(output); - - WriteRunRequest(input_stream, transaction_end.c_str()); - session.Execute(); - ASSERT_FALSE(session.db_accessor_); - CheckSuccessMessage(output); - ASSERT_EQ(session.state_, State::Result); - - ExecuteCommand(input_stream, session, pullall_req, sizeof(pullall_req)); - session.Execute(); - ASSERT_EQ(session.state_, State::Idle); - ASSERT_FALSE(session.db_accessor_); - CheckSuccessMessage(output); - } -} - -TEST(BoltSession, ExplicitTransactionInvalidQuery) { - std::vector<std::string> transaction_ends = {"COMMIT", "ROLLBACK"}; - - for (const auto &transaction_end : transaction_ends) { - INIT_VARS; - - ExecuteHandshake(input_stream, session, output); - ExecuteInit(input_stream, session, output); - - WriteRunRequest(input_stream, "BEGIN"); - session.Execute(); - ASSERT_EQ(session.state_, State::Result); - ASSERT_TRUE(session.db_accessor_); - CheckSuccessMessage(output); - - ExecuteCommand(input_stream, session, pullall_req, sizeof(pullall_req)); - session.Execute(); - ASSERT_EQ(session.state_, State::Idle); - ASSERT_TRUE(session.db_accessor_); - CheckSuccessMessage(output); - - WriteRunRequest(input_stream, "MATCH ("); - session.Execute(); - ASSERT_EQ(session.state_, State::ErrorWaitForRollback); - ASSERT_TRUE(session.db_accessor_); - CheckFailureMessage(output); - - ExecuteCommand(input_stream, session, pullall_req, sizeof(pullall_req)); - session.Execute(); - ASSERT_EQ(session.state_, State::ErrorWaitForRollback); - ASSERT_TRUE(session.db_accessor_); - CheckIgnoreMessage(output); - - ExecuteCommand(input_stream, session, ackfailure_req, - sizeof(ackfailure_req)); - session.Execute(); - ASSERT_EQ(session.state_, State::WaitForRollback); - ASSERT_TRUE(session.db_accessor_); - CheckSuccessMessage(output); - - WriteRunRequest(input_stream, transaction_end.c_str()); - - if (transaction_end == "ROLLBACK") { - session.Execute(); - ASSERT_EQ(session.state_, State::Result); - ASSERT_FALSE(session.db_accessor_); - CheckSuccessMessage(output); - - ExecuteCommand(input_stream, session, pullall_req, sizeof(pullall_req)); - session.Execute(); - ASSERT_EQ(session.state_, State::Idle); - ASSERT_FALSE(session.db_accessor_); - CheckSuccessMessage(output); - - } else { - ASSERT_THROW(session.Execute(), SessionException); - ASSERT_EQ(session.state_, State::Close); - CheckFailureMessage(output); - } - } -} - int main(int argc, char **argv) { google::InitGoogleLogging(argv[0]); ::testing::InitGoogleTest(&argc, argv);