From da0e4a5b12362d5c757901284eba144f7c00b70b Mon Sep 17 00:00:00 2001 From: Mislav Bradac Date: Thu, 3 Aug 2017 15:53:41 +0200 Subject: [PATCH] Implement explicitly started transactions Summary: Fix tests Reviewers: buda, mferencevic Reviewed By: mferencevic Subscribers: mferencevic, pullbot Differential Revision: https://phabricator.memgraph.io/D623 --- src/communication/bolt/v1/session.hpp | 37 +++- src/communication/bolt/v1/state.hpp | 10 +- src/communication/bolt/v1/states/error.hpp | 17 +- .../bolt/v1/states/idle_result.hpp | 194 +++++++++++------- tests/drivers/python/transactions.py | 36 ++++ tests/unit/bolt_session.cpp | 156 ++++++++++++-- 6 files changed, 359 insertions(+), 91 deletions(-) create mode 100644 tests/drivers/python/transactions.py diff --git a/src/communication/bolt/v1/session.hpp b/src/communication/bolt/v1/session.hpp index bb557cb64..d00c7580e 100644 --- a/src/communication/bolt/v1/session.hpp +++ b/src/communication/bolt/v1/session.hpp @@ -42,6 +42,11 @@ class Session { event_.data.ptr = this; } + ~Session() { + debug_assert(!db_accessor_, + "Transaction should have already be closed in Close"); + } + /** * @return is the session in a valid state */ @@ -90,10 +95,11 @@ class Session { break; case State::Idle: case State::Result: - state_ = StateIdleResultRun(*this, state_); + case State::WaitForRollback: + state_ = StateExecutingRun(*this, state_); break; case State::ErrorIdle: - case State::ErrorResult: + case State::ErrorWaitForRollback: state_ = StateErrorRun(*this, state_); break; case State::Close: @@ -136,9 +142,30 @@ class Session { */ void Close() { DLOG(INFO) << "Closing session"; + if (db_accessor_) { + Abort(); + } this->socket_.Close(); } + /** + * Commits associated transaction. + */ + void Commit() { + debug_assert(db_accessor_, "Commit called and there is no transaction"); + db_accessor_->commit(); + db_accessor_ = nullptr; + } + + /** + * Aborts associated transaction. 
+ */ + void Abort() { + debug_assert(db_accessor_, "Abort called and there is no transaction"); + db_accessor_->abort(); + db_accessor_ = nullptr; + } + GraphDbAccessor ActiveDb() { return dbms_.active(); } // TODO: Rethink if there is a way to hide some members. At the momemnt all of @@ -158,8 +185,9 @@ class Session { io::network::Epoll::Event event_; bool connected_{false}; State state_{State::Handshake}; - // Active transaction of the session, can be null. - tx::Transaction *transaction_; + // GraphDbAccessor of active transaction in the session, can be null if there + // is no associated transaction. + std::unique_ptr db_accessor_; private: void ClientFailureInvalidData() { @@ -167,6 +195,7 @@ class Session { state_ = State::Close; // don't care about the return status because this is always // called when we are about to close the connection to the client + encoder_buffer_.Clear(); encoder_.MessageFailure({{"code", "Memgraph.InvalidData"}, {"message", "The client has sent invalid data!"}}); // close the connection diff --git a/src/communication/bolt/v1/state.hpp b/src/communication/bolt/v1/state.hpp index 305889633..b848ca6cd 100644 --- a/src/communication/bolt/v1/state.hpp +++ b/src/communication/bolt/v1/state.hpp @@ -31,6 +31,12 @@ enum class State : uint8_t { */ Result, + /** + * There was an acked error in explicitly started transaction, now we are + * waiting for "ROLLBACK" in RUN command. + */ + WaitForRollback, + /** * This state handles errors, if client handles error response correctly next * state is Idle. @@ -39,9 +45,9 @@ enum class State : uint8_t { /** * This state handles errors, if client handles error response correctly next - * state is Result. + * state is WaitForRollback. 
*/ - ErrorResult, + ErrorWaitForRollback, /** * This is a 'virtual' state (it doesn't have a run function) which tells diff --git a/src/communication/bolt/v1/states/error.hpp b/src/communication/bolt/v1/states/error.hpp index 7c99a3186..dfbf35366 100644 --- a/src/communication/bolt/v1/states/error.hpp +++ b/src/communication/bolt/v1/states/error.hpp @@ -27,26 +27,35 @@ State StateErrorRun(Session &session, State state) { DLOG(INFO) << fmt::format("Message signature is: 0x{:02X}", underlying_cast(signature)); - // clear the data buffer if it has any leftover data + // Clear the data buffer if it has any leftover data. session.encoder_buffer_.Clear(); if (signature == Signature::AckFailure || signature == Signature::Reset) { - if (signature == Signature::AckFailure) + if (signature == Signature::AckFailure) { DLOG(INFO) << "AckFailure received"; - else + } else { DLOG(INFO) << "Reset received"; + } if (!session.encoder_.MessageSuccess()) { DLOG(WARNING) << "Couldn't send success message!"; return State::Close; } if (signature == Signature::Reset) { + if (session.db_accessor_) { + session.Abort(); + } return State::Idle; } + + // We got AckFailure, so get back to the right state. 
if (state == State::ErrorIdle) { return State::Idle; + } else if (state == State::ErrorWaitForRollback) { + return State::WaitForRollback; + } else { + permanent_assert(false, "Shouldn't happen"); } - return State::Result; } else { uint8_t value = underlying_cast(marker); diff --git a/src/communication/bolt/v1/states/idle_result.hpp b/src/communication/bolt/v1/states/idle_result.hpp index 58be3af2e..d1ec0662a 100644 --- a/src/communication/bolt/v1/states/idle_result.hpp +++ b/src/communication/bolt/v1/states/idle_result.hpp @@ -34,41 +34,93 @@ State HandleRun(Session &session, State state, Marker marker) { return State::Close; } - auto db_accessor = session.dbms_.active(); - DLOG(INFO) << fmt::format("[ActiveDB] '{}'", db_accessor->name()); + if (state == State::WaitForRollback) { + if (query.Value() == "ROLLBACK") { + session.Abort(); + // One MessageSuccess for RUN command should be flushed. + session.encoder_.MessageSuccess(); + // One for PULL_ALL should be chunked. + session.encoder_.MessageSuccess({}, false); + return State::Result; + } + DLOG(WARNING) << "Expected RUN \"ROLLBACK\" not received!"; + // Client could potentially recover if we move to error state, but we don't + // implement rollback of single command in transaction, only rollback of + // whole transaction so we can't continue in this transaction if we receive + // new RUN command. + return State::Close; + } if (state != State::Idle) { - // TODO: We shouldn't clear the buffer and move to ErrorIdle state, but send - // MessageFailure without sending data that is already in buffer and move to - // ErrorResult state. 
- session.encoder_buffer_.Clear(); - - // send failure message - bool unexpected_run_fail_sent = session.encoder_.MessageFailure( - {{"code", "Memgraph.QueryExecutionFail"}, - {"message", "Unexpected RUN command"}}); - + // Client could potentially recover if we move to error state, but there is + // no legitimate situation in which well working client would end up in this + // situation. DLOG(WARNING) << "Unexpected RUN command!"; - if (!unexpected_run_fail_sent) { - DLOG(WARNING) << "Couldn't send failure message!"; - return State::Close; - } else { - return State::ErrorIdle; - } + return State::Close; } debug_assert(!session.encoder_buffer_.HasData(), "There should be no data to write in this state"); + DLOG(INFO) << fmt::format("[Run] '{}'", query.Value()); + bool in_explicit_transaction = false; + if (session.db_accessor_) { + // Transaction already exists. + in_explicit_transaction = true; + } else { + // Create new transaction. + session.db_accessor_ = session.dbms_.active(); + } + + DLOG(INFO) << fmt::format("[ActiveDB] '{}'", session.db_accessor_->name()); + + // If there was not explicitly started transaction before maybe we are + // starting one now. + if (!in_explicit_transaction && query.Value() == "BEGIN") { + // Check if query string is "BEGIN". If it is then we should start + // transaction and wait for in-transaction queries. + // TODO: "BEGIN" is not defined by bolt protocol or opencypher so we should + // test if all drivers really denote transaction start with "BEGIN" string. + // Same goes for "ROLLBACK" and "COMMIT". + // + // One MessageSuccess for RUN command should be flushed. + session.encoder_.MessageSuccess(); + // One for PULL_ALL should be chunked. + session.encoder_.MessageSuccess({}, false); + return State::Result; + } + + if (in_explicit_transaction) { + if (query.Value() == "COMMIT") { + session.Commit(); + // One MessageSuccess for RUN command should be flushed. 
+ session.encoder_.MessageSuccess(); + // One for PULL_ALL should be chunked. + session.encoder_.MessageSuccess({}, false); + return State::Result; + } else if (query.Value() == "ROLLBACK") { + session.Abort(); + // One MessageSuccess for RUN command should be flushed. + session.encoder_.MessageSuccess(); + // One for PULL_ALL should be chunked. + session.encoder_.MessageSuccess({}, false); + return State::Result; + } + session.db_accessor_->advance_command(); + } + try { - DLOG(INFO) << fmt::format("[Run] '{}'", query.Value()); auto is_successfully_executed = session.query_engine_.Run( - query.Value(), *db_accessor, session.output_stream_, + query.Value(), *session.db_accessor_, + session.output_stream_, params.Value>()); + // TODO: once we remove compiler from query_engine we can change return type + // to void and not do these checks here. if (!is_successfully_executed) { - // abort transaction - db_accessor->abort(); + if (!in_explicit_transaction) { + session.Abort(); + } // clear any leftover messages in the buffer session.encoder_buffer_.Clear(); @@ -86,26 +138,33 @@ State HandleRun(Session &session, State state, Marker marker) { if (!exec_fail_sent) { DLOG(WARNING) << "Couldn't send failure message!"; return State::Close; - } else { - return State::ErrorIdle; } - } else { - db_accessor->commit(); - // The query engine has already stored all query data in the buffer. - // We should only send the first chunk now which is the success - // message which contains header data. The rest of this data (records - // and summary) will be sent after a PULL_ALL command from the client. - if (!session.encoder_buffer_.FlushFirstChunk()) { - DLOG(WARNING) << "Couldn't flush header data from the buffer!"; - return State::Close; + if (in_explicit_transaction) { + // TODO: Neo4j only discards changes from last query and can possibly + // continue. We can't discard changes from one or multiple commands in + // same transaction so we need to rollback whole transaction. 
One day + // we should probably support neo4j's way. + return State::ErrorWaitForRollback; } - return State::Result; + return State::ErrorIdle; } + if (!in_explicit_transaction) { + session.Commit(); + } + // The query engine has already stored all query data in the buffer. + // We should only send the first chunk now which is the success + // message which contains header data. The rest of this data (records + // and summary) will be sent after a PULL_ALL command from the client. + if (!session.encoder_buffer_.FlushFirstChunk()) { + DLOG(WARNING) << "Couldn't flush header data from the buffer!"; + return State::Close; + } + return State::Result; + // TODO: Remove duplication in error handling. } catch (const utils::BasicException &e) { // clear header success message session.encoder_buffer_.Clear(); - db_accessor->abort(); bool fail_sent = session.encoder_.MessageFailure( {{"code", "Memgraph.Exception"}, {"message", e.what()}}); DLOG(WARNING) << fmt::format("Error message: {}", e.what()); @@ -113,12 +172,14 @@ State HandleRun(Session &session, State state, Marker marker) { DLOG(WARNING) << "Couldn't send failure message!"; return State::Close; } - return State::ErrorIdle; - + if (!in_explicit_transaction) { + session.Abort(); + return State::ErrorIdle; + } + return State::ErrorWaitForRollback; } catch (const utils::StacktraceException &e) { // clear header success message session.encoder_buffer_.Clear(); - db_accessor->abort(); bool fail_sent = session.encoder_.MessageFailure( {{"code", "Memgraph.Exception"}, {"message", e.what()}}); DLOG(WARNING) << fmt::format("Error message: {}", e.what()); @@ -127,12 +188,14 @@ State HandleRun(Session &session, State state, Marker marker) { DLOG(WARNING) << "Couldn't send failure message!"; return State::Close; } - return State::ErrorIdle; - - } catch (std::exception &e) { + if (!in_explicit_transaction) { + session.Abort(); + return State::ErrorIdle; + } + return State::ErrorWaitForRollback; + } catch (const std::exception 
&e) { // clear header success message session.encoder_buffer_.Clear(); - db_accessor->abort(); bool fail_sent = session.encoder_.MessageFailure( {{"code", "Memgraph.Exception"}, {"message", @@ -143,10 +206,15 @@ State HandleRun(Session &session, State state, Marker marker) { DLOG(WARNING) << "Couldn't send failure message!"; return State::Close; } - return State::ErrorIdle; + if (!in_explicit_transaction) { + session.Abort(); + return State::ErrorIdle; + } + return State::ErrorWaitForRollback; } } +// TODO: Get rid of duplications in PullAll/DiscardAll functions. template State HandlePullAll(Session &session, State state, Marker marker) { DLOG(INFO) << "[PullAll]"; @@ -156,21 +224,14 @@ State HandlePullAll(Session &session, State state, Marker marker) { underlying_cast(marker)); return State::Close; } + if (state != State::Result) { - // the buffer doesn't have data, return a failure message - bool data_fail_sent = session.encoder_.MessageFailure( - {{"code", "Memgraph.Exception"}, - {"message", - "There is no data to " - "send, you have to execute a RUN command before a PULL_ALL!"}}); - if (!data_fail_sent) { - DLOG(WARNING) << "Couldn't send failure message!"; - return State::Close; - } - return State::ErrorIdle; + DLOG(WARNING) << "Unexpected PULL_ALL!"; + // Same as `unexpected RUN` case. + return State::Close; } - // flush pending data to the client, the success message is streamed - // from the query engine, it contains statistics from the query run + // Flush pending data to the client, the success message is streamed + // from the query engine, it contains statistics from the query run. 
if (!session.encoder_buffer_.Flush()) { DLOG(WARNING) << "Couldn't flush data from the buffer!"; return State::Close; @@ -189,19 +250,11 @@ State HandleDiscardAll(Session &session, State state, Marker marker) { } if (state != State::Result) { - bool data_fail_discard = session.encoder_.MessageFailure( - {{"code", "Memgraph.Exception"}, - {"message", - "There is no data to " - "discard, you have to execute a RUN command before a " - "DISCARD_ALL!"}}); - if (!data_fail_discard) { - DLOG(WARNING) << "Couldn't send failure message!"; - return State::Close; - } - return State::ErrorIdle; + DLOG(WARNING) << "Unexpected DISCARD_ALL!"; + // Same as `unexpected RUN` case. + return State::Close; } - // clear all pending data and send a success message + // Clear all pending data and send a success message. session.encoder_buffer_.Clear(); if (!session.encoder_.MessageSuccess()) { DLOG(WARNING) << "Couldn't send success message!"; @@ -233,6 +286,9 @@ State HandleReset(Session &session, State, Marker marker) { DLOG(WARNING) << "Couldn't send success message!"; return State::Close; } + if (session.db_accessor_) { + session.Abort(); + } return State::Idle; } @@ -243,7 +299,7 @@ State HandleReset(Session &session, State, Marker marker) { * @param session the session that should be used for the run */ template -State StateIdleResultRun(Session &session, State state) { +State StateExecutingRun(Session &session, State state) { Marker marker; Signature signature; if (!session.decoder_.ReadMessageHeader(&signature, &marker)) { diff --git a/tests/drivers/python/transactions.py b/tests/drivers/python/transactions.py new file mode 100644 index 000000000..94ff95e01 --- /dev/null +++ b/tests/drivers/python/transactions.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +from neo4j.v1 import GraphDatabase, basic_auth, CypherError + +driver = GraphDatabase.driver("bolt://localhost:7687", + auth=basic_auth("", ""), + encrypted=False) + +def tx_error(tx, name, name2): + a = 
tx.run("CREATE (a:Person {name: $name}) RETURN a", name=name).data() + print(a[0]['a']) + tx.run("CREATE (").consume() + a = tx.run("CREATE (a:Person {name: $name}) RETURN a", name=name2).data() + print(a[0]['a']) + +def tx_good(tx, name, name2): + a = tx.run("CREATE (a:Person {name: $name}) RETURN a", name=name).data() + print(a[0]['a']) + a = tx.run("CREATE (a:Person {name: $name}) RETURN a", name=name2).data() + print(a[0]['a']) + +def add_person(f, name, name2): + with driver.session() as session: + session.write_transaction(f, name, name2) + +try: + add_person(tx_error, "mirko", "slavko") +except CypherError: + pass + +add_person(tx_good, "mirka", "slavka") + +driver.close() + +print("All ok!") diff --git a/tests/unit/bolt_session.cpp b/tests/unit/bolt_session.cpp index 0ff88ba0f..bc693481f 100644 --- a/tests/unit/bolt_session.cpp +++ b/tests/unit/bolt_session.cpp @@ -8,6 +8,7 @@ DECLARE_bool(interpret); +// TODO: This could be done in fixture. // Shortcuts for writing variable initializations in tests #define INIT_VARS \ Dbms dbms; \ @@ -54,12 +55,33 @@ void WriteChunkHeader(SessionT &session, uint16_t len) { // Write bolt chunk tail (two zeros) void WriteChunkTail(SessionT &session) { WriteChunkHeader(session, 0); } -// Check that the server responded with a failure message +// Check that the server responded with a failure message. void CheckFailureMessage(std::vector &output) { ASSERT_GE(output.size(), 6); // skip the first two bytes because they are the chunk header ASSERT_EQ(output[2], 0xB1); // tiny struct 1 ASSERT_EQ(output[3], 0x7F); // signature failure + output.clear(); +} + +// Check that the server responded with a success message. 
+void CheckSuccessMessage(std::vector &output, bool clear = true) { + ASSERT_GE(output.size(), 6); + // skip the first two bytes because they are the chunk header + ASSERT_EQ(output[2], 0xB1); // tiny struct 1 + ASSERT_EQ(output[3], 0x70); // signature success + if (clear) { + output.clear(); + } +} + +// Check that the server responded with an ignore message. +void CheckIgnoreMessage(std::vector &output) { + ASSERT_GE(output.size(), 6); + // skip the first two bytes because they are the chunk header + ASSERT_EQ(output[2], 0xB0); + ASSERT_EQ(output[3], 0x7E); // signature ignore + output.clear(); } // Execute and check a correct handshake @@ -330,13 +352,11 @@ TEST(BoltSession, ExecutePullAllBufferEmpty) { session.socket_.SetWriteSuccess(i == 0); ExecuteCommand(session, pullall_req, sizeof(pullall_req)); + ASSERT_EQ(session.state_, StateT::Close); + ASSERT_FALSE(session.socket_.IsOpen()); if (i == 0) { - ASSERT_EQ(session.state_, StateT::ErrorIdle); - ASSERT_TRUE(session.socket_.IsOpen()); CheckFailureMessage(output); } else { - ASSERT_EQ(session.state_, StateT::Close); - ASSERT_FALSE(session.socket_.IsOpen()); ASSERT_EQ(output.size(), 0); } } @@ -440,11 +460,8 @@ TEST(BoltSession, ErrorRunAfterRun) { WriteRunRequest(session, "MATCH (n) RETURN n"); session.Execute(); - // Run after run fails, but we still keep results. - // TODO: actually we don't, but we should. Change state to ErrorResult once - // that is fixed. 
- ASSERT_EQ(session.state_, StateT::ErrorIdle); - ASSERT_TRUE(session.socket_.IsOpen()); + ASSERT_EQ(session.state_, StateT::Close); + ASSERT_FALSE(session.socket_.IsOpen()); } TEST(BoltSession, ErrorCantCleanup) { @@ -591,8 +608,8 @@ TEST(BoltSession, PartialChunk) { session.Execute(); - ASSERT_EQ(session.state_, StateT::ErrorIdle); - ASSERT_TRUE(session.socket_.IsOpen()); + ASSERT_EQ(session.state_, StateT::Close); + ASSERT_FALSE(session.socket_.IsOpen()); ASSERT_GT(output.size(), 0); PrintOutput(output); } @@ -613,6 +630,121 @@ TEST(BoltSession, InvalidChunk) { CheckFailureMessage(output); } +TEST(BoltSession, ExplicitTransactionValidQueries) { + // It is not really easy to check if we committed or aborted transaction except + // by faking GraphDb/TxEngine... + std::vector transaction_ends = {"COMMIT", "ROLLBACK"}; + + for (const auto &transaction_end : transaction_ends) { + INIT_VARS; + + ExecuteHandshake(session, output); + ExecuteInit(session, output); + + WriteRunRequest(session, "BEGIN"); + session.Execute(); + ASSERT_EQ(session.state_, StateT::Result); + ASSERT_TRUE(session.db_accessor_); + CheckSuccessMessage(output); + + ExecuteCommand(session, pullall_req, sizeof(pullall_req)); + session.Execute(); + ASSERT_EQ(session.state_, StateT::Idle); + ASSERT_TRUE(session.db_accessor_); + CheckSuccessMessage(output); + + WriteRunRequest(session, "MATCH (n) RETURN n"); + session.Execute(); + ASSERT_EQ(session.state_, StateT::Result); + ASSERT_TRUE(session.db_accessor_); + CheckSuccessMessage(output); + + ExecuteCommand(session, pullall_req, sizeof(pullall_req)); + session.Execute(); + ASSERT_EQ(session.state_, StateT::Idle); + ASSERT_TRUE(session.db_accessor_); + CheckSuccessMessage(output); + + WriteRunRequest(session, transaction_end.c_str()); + session.Execute(); + ASSERT_FALSE(session.db_accessor_); + CheckSuccessMessage(output); + ASSERT_EQ(session.state_, StateT::Result); + + ExecuteCommand(session, pullall_req, sizeof(pullall_req)); + session.Execute(); + 
ASSERT_EQ(session.state_, StateT::Idle); + ASSERT_FALSE(session.db_accessor_); + CheckSuccessMessage(output); + + ASSERT_TRUE(session.socket_.IsOpen()); + } +} + +TEST(BoltSession, ExplicitTransactionInvalidQuery) { + std::vector transaction_ends = {"COMMIT", "ROLLBACK"}; + + for (const auto &transaction_end : transaction_ends) { + INIT_VARS; + + ExecuteHandshake(session, output); + ExecuteInit(session, output); + + WriteRunRequest(session, "BEGIN"); + session.Execute(); + ASSERT_EQ(session.state_, StateT::Result); + ASSERT_TRUE(session.db_accessor_); + CheckSuccessMessage(output); + + ExecuteCommand(session, pullall_req, sizeof(pullall_req)); + session.Execute(); + ASSERT_EQ(session.state_, StateT::Idle); + ASSERT_TRUE(session.db_accessor_); + CheckSuccessMessage(output); + + WriteRunRequest(session, "MATCH ("); + session.Execute(); + ASSERT_EQ(session.state_, StateT::ErrorWaitForRollback); + ASSERT_TRUE(session.db_accessor_); + CheckFailureMessage(output); + + ExecuteCommand(session, pullall_req, sizeof(pullall_req)); + session.Execute(); + ASSERT_EQ(session.state_, StateT::ErrorWaitForRollback); + ASSERT_TRUE(session.db_accessor_); + CheckIgnoreMessage(output); + + ExecuteCommand(session, ackfailure_req, sizeof(ackfailure_req)); + session.Execute(); + ASSERT_EQ(session.state_, StateT::WaitForRollback); + ASSERT_TRUE(session.db_accessor_); + CheckSuccessMessage(output); + + WriteRunRequest(session, transaction_end.c_str()); + session.Execute(); + + if (transaction_end == "ROLLBACK") { + ASSERT_EQ(session.state_, StateT::Result); + ASSERT_FALSE(session.db_accessor_); + ASSERT_TRUE(session.socket_.IsOpen()); + CheckSuccessMessage(output); + + ExecuteCommand(session, pullall_req, sizeof(pullall_req)); + session.Execute(); + ASSERT_EQ(session.state_, StateT::Idle); + ASSERT_FALSE(session.db_accessor_); + ASSERT_TRUE(session.socket_.IsOpen()); + CheckSuccessMessage(output); + + } else { + ASSERT_EQ(session.state_, StateT::Close); + ASSERT_FALSE(session.db_accessor_); 
+ ASSERT_FALSE(session.socket_.IsOpen()); + CheckFailureMessage(output); + } + } +} + int main(int argc, char **argv) { google::InitGoogleLogging(argv[0]); // Set the interpret to true to avoid calling the compiler which only