diff --git a/src/communication/bolt/client.cpp b/src/communication/bolt/client.cpp index 20e2699c4..5a8dbf9d2 100644 --- a/src/communication/bolt/client.cpp +++ b/src/communication/bolt/client.cpp @@ -146,9 +146,10 @@ QueryData Client::Execute(const std::string &query, const std::map ExceptionToErrorMessage(const std::ex namespace details { -template -State HandleRun(TSession &session, const State state, const Value &query, const Value ¶ms) { - if (state != State::Idle) { - // Client could potentially recover if we move to error state, but there is - // no legitimate situation in which well working client would end up in this - // situation. - spdlog::trace("Unexpected RUN command!"); - return State::Close; - } - - DMG_ASSERT(!session.encoder_buffer_.HasData(), "There should be no data to write in this state"); - - spdlog::debug("[Run] '{}'", query.ValueString()); - - try { - // Interpret can throw. - const auto [header, qid] = session.Interpret(query.ValueString(), params.ValueMap()); - // Convert std::string to Value - std::vector vec; - std::map data; - vec.reserve(header.size()); - for (auto &i : header) vec.emplace_back(std::move(i)); - data.emplace("fields", std::move(vec)); - // Send the header. - if (!session.encoder_.MessageSuccess(data)) { - spdlog::trace("Couldn't send query header!"); - return State::Close; - } - return State::Result; - } catch (const std::exception &e) { - return HandleFailure(session, e); - } -} - template State HandlePullDiscard(TSession &session, std::optional n, std::optional qid) { try { @@ -229,7 +195,36 @@ State HandleRunV1(TSession &session, const State state, const Marker marker) { return State::Close; } - return details::HandleRun(session, state, query, params); + if (state != State::Idle) { + // Client could potentially recover if we move to error state, but there is + // no legitimate situation in which well working client would end up in this + // situation. + spdlog::trace("Unexpected RUN command!"); + return State::Close; + } + + DMG_ASSERT(!session.encoder_buffer_.HasData(), "There should be no data to write in this state"); + + spdlog::debug("[Run] '{}'", query.ValueString()); + + try { + // Interpret can throw. + const auto [header, qid] = session.Interpret(query.ValueString(), params.ValueMap()); + // Convert std::string to Value + std::vector vec; + std::map data; + vec.reserve(header.size()); + for (auto &i : header) vec.emplace_back(std::move(i)); + data.emplace("fields", std::move(vec)); + // Send the header. + if (!session.encoder_.MessageSuccess(data)) { + spdlog::trace("Couldn't send query header!"); + return State::Close; + } + return State::Result; + } catch (const std::exception &e) { + return HandleFailure(session, e); + } } template @@ -257,7 +252,40 @@ State HandleRunV4(TSession &session, const State state, const Marker marker) { spdlog::trace("Couldn't read extra field!"); } - return details::HandleRun(session, state, query, params); + if (state != State::Idle) { + // Client could potentially recover if we move to error state, but there is + // no legitimate situation in which well working client would end up in this + // situation. + spdlog::trace("Unexpected RUN command!"); + return State::Close; + } + + DMG_ASSERT(!session.encoder_buffer_.HasData(), "There should be no data to write in this state"); + + spdlog::debug("[Run] '{}'", query.ValueString()); + + try { + // Interpret can throw. + const auto [header, qid] = session.Interpret(query.ValueString(), params.ValueMap()); + // Convert std::string to Value + std::vector vec; + std::map data; + vec.reserve(header.size()); + for (auto &i : header) vec.emplace_back(std::move(i)); + data.emplace("fields", std::move(vec)); + if (qid.has_value()) { + data.emplace("qid", Value{*qid}); + } + + // Send the header. + if (!session.encoder_.MessageSuccess(data)) { + spdlog::trace("Couldn't send query header!"); + return State::Close; + } + return State::Result; + } catch (const std::exception &e) { + return HandleFailure(session, e); + } } template diff --git a/tests/integration/transactions/tester.cpp b/tests/integration/transactions/tester.cpp index 06f89202c..407b11892 100644 --- a/tests/integration/transactions/tester.cpp +++ b/tests/integration/transactions/tester.cpp @@ -45,6 +45,21 @@ class BoltClient : public ::testing::Test { return true; } + bool ExecuteAndCheckQid(const std::string &query, int qid, const std::string &message = "") { + try { + auto ret = client_.Execute(query, {}); + if (ret.metadata["qid"].ValueInt() != qid) { + return false; + } + } catch (const ClientQueryException &e) { + if (message != "") { + EXPECT_EQ(e.what(), message); + } + throw; + } + return true; + } + int64_t GetCount() { auto ret = client_.Execute("match (n) return count(n)", {}); EXPECT_EQ(ret.records.size(), 1); @@ -461,6 +476,18 @@ TEST_F(BoltClient, MixedCaseAndWhitespace) { EXPECT_FALSE(TransactionActive()); } +TEST_F(BoltClient, TestQid) { + for (int i = 0; i < 3; ++i) { + EXPECT_TRUE(Execute("match (n) return count(n)")); + } + EXPECT_TRUE(Execute("begin")); + for (int i = 0; i < 3; ++i) { + EXPECT_TRUE(ExecuteAndCheckQid("match (n) return count(n)", i + 1)); + } + EXPECT_TRUE(Execute("commit")); + EXPECT_FALSE(TransactionActive()); +} + int main(int argc, char **argv) { ::testing::InitGoogleTest(&argc, argv); gflags::ParseCommandLineFlags(&argc, &argv, true);