From 03f460127e874a8d294b9f18a93c0dcc78d5502e Mon Sep 17 00:00:00 2001
From: Matej Ferencevic <matej.ferencevic@memgraph.io>
Date: Fri, 21 Sep 2018 16:53:27 +0200
Subject: [PATCH] Don't kill active sessions with inactivity timeout

Reviewers: teon.banek, buda

Reviewed By: buda

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D1608
---
 src/communication/session.hpp   | 12 +++++----
 tests/unit/network_timeouts.cpp | 43 +++++++++++++++++++++++++++++----
 2 files changed, 45 insertions(+), 10 deletions(-)

diff --git a/src/communication/session.hpp b/src/communication/session.hpp
index d91befc85..4ed1425fc 100644
--- a/src/communication/session.hpp
+++ b/src/communication/session.hpp
@@ -19,6 +19,7 @@
 #include "communication/helpers.hpp"
 #include "io/network/socket.hpp"
 #include "io/network/stream_buffer.hpp"
+#include "utils/on_scope_exit.hpp"
 #include "utils/thread/sync.hpp"
 
 namespace communication {
@@ -142,7 +143,8 @@ class Session final {
    */
   bool Execute() {
     // Refresh the last event time in the session.
-    RefreshLastEventTime();
+    RefreshLastEventTime(true);
+    utils::OnScopeExit on_exit([this] { RefreshLastEventTime(false); });
 
     // Allocate the buffer to fill the data.
     auto buf = input_buffer_.write_end()->Allocate();
@@ -212,9 +214,6 @@ class Session final {
     // Execute the session.
     session_.Execute();
 
-    // Refresh the last event time.
-    RefreshLastEventTime();
-
     return false;
   }
 
@@ -226,6 +225,7 @@ class Session final {
    */
   bool TimedOut() {
     std::unique_lock<utils::SpinLock> guard(lock_);
+    if (execution_active_) return false;
     return last_event_time_ + std::chrono::seconds(inactivity_timeout_sec_) <
            std::chrono::steady_clock::now();
   }
@@ -236,8 +236,9 @@ class Session final {
   io::network::Socket &socket() { return socket_; }
 
  private:
-  void RefreshLastEventTime() {
+  void RefreshLastEventTime(bool active) {
     std::unique_lock<utils::SpinLock> guard(lock_);
+    execution_active_ = active;
     last_event_time_ = std::chrono::steady_clock::now();
   }
 
@@ -300,6 +301,7 @@ class Session final {
   // Time of the last event and associated lock.
   std::chrono::time_point<std::chrono::steady_clock> last_event_time_{
       std::chrono::steady_clock::now()};
+  bool execution_active_{false};
   utils::SpinLock lock_;
   const int inactivity_timeout_sec_;
 
diff --git a/tests/unit/network_timeouts.cpp b/tests/unit/network_timeouts.cpp
index 0a7315b8a..4c7bee675 100644
--- a/tests/unit/network_timeouts.cpp
+++ b/tests/unit/network_timeouts.cpp
@@ -1,5 +1,6 @@
 #include <chrono>
 #include <iostream>
+#include <thread>
 
 #include <gflags/gflags.h>
 #include <glog/logging.h>
@@ -25,6 +26,9 @@ class TestSession {
                      reinterpret_cast<const char *>(input_stream_->data()),
                      input_stream_->size())
               << "'";
+    if (input_stream_->data()[0] == 'e') {
+      std::this_thread::sleep_for(std::chrono::seconds(5));
+    }
     output_stream_->Write(input_stream_->data(), input_stream_->size());
     input_stream_->Shift(input_stream_->size());
   }
@@ -34,9 +38,10 @@ class TestSession {
   communication::OutputStream *output_stream_;
 };
 
-const std::string query("timeout test");
+const std::string safe_query("tttt");
+const std::string expensive_query("eeee");
 
-bool QueryServer(io::network::Socket &socket) {
+bool QueryServer(io::network::Socket &socket, const std::string &query) {
   if (!socket.Write(query)) return false;
   char response[105];
   int len = 0;
@@ -62,19 +67,47 @@ TEST(NetworkTimeouts, InactiveSession) {
   ASSERT_TRUE(client.Connect(server.endpoint()));
 
   // Send some data to the server.
-  ASSERT_TRUE(QueryServer(client));
+  ASSERT_TRUE(QueryServer(client, safe_query));
 
   for (int i = 0; i < 3; ++i) {
     // After this sleep the session should still be alive.
     std::this_thread::sleep_for(500ms);
 
     // Send some data to the server.
-    ASSERT_TRUE(QueryServer(client));
+    ASSERT_TRUE(QueryServer(client, safe_query));
   }
 
   // After this sleep the session should have timed out.
   std::this_thread::sleep_for(3500ms);
-  ASSERT_FALSE(QueryServer(client));
+  ASSERT_FALSE(QueryServer(client, safe_query));
+}
+
+TEST(NetworkTimeouts, ActiveSession) {
+  // Instantiate the server and set the session timeout to 2 seconds.
+  TestData test_data;
+  communication::ServerContext context;
+  communication::Server<TestSession, TestData> server{
+      {"127.0.0.1", 0}, &test_data, &context, 2, "Test", 1};
+
+  // Create the client and connect to the server.
+  io::network::Socket client;
+  ASSERT_TRUE(client.Connect(server.endpoint()));
+
+  // Send some data to the server.
+  ASSERT_TRUE(QueryServer(client, expensive_query));
+
+  for (int i = 0; i < 3; ++i) {
+    // After this sleep the session should still be alive.
+    std::this_thread::sleep_for(500ms);
+
+    // Send some data to the server.
+    ASSERT_TRUE(QueryServer(client, safe_query));
+  }
+
+  // After this sleep the session should have timed out.
+  std::this_thread::sleep_for(3500ms);
+  ASSERT_FALSE(QueryServer(client, safe_query));
+
 }
 
 int main(int argc, char **argv) {