From 7beeab6964639aefe8869b0ac256bd0a6af56e4f Mon Sep 17 00:00:00 2001
From: imilinovic <ivan.milinovic123@gmail.com>
Date: Thu, 14 Mar 2024 01:38:52 +0100
Subject: [PATCH] add tests and fix wrong logic

---
 src/storage/v2/disk/storage.cpp           | 11 ++--
 tests/e2e/CMakeLists.txt                  |  1 +
 tests/e2e/concurrent_write/CMakeLists.txt |  7 +++
 tests/e2e/concurrent_write/client.py      | 33 ++++++++++
 tests/e2e/concurrent_write/workloads.yaml | 33 ++++++++++
 tests/e2e/concurrent_write/write.py       | 73 +++++++++++++++++++++++
 6 files changed, 152 insertions(+), 6 deletions(-)
 create mode 100644 tests/e2e/concurrent_write/CMakeLists.txt
 create mode 100644 tests/e2e/concurrent_write/client.py
 create mode 100644 tests/e2e/concurrent_write/workloads.yaml
 create mode 100644 tests/e2e/concurrent_write/write.py

diff --git a/src/storage/v2/disk/storage.cpp b/src/storage/v2/disk/storage.cpp
index 3d99fe770..eaeb780c8 100644
--- a/src/storage/v2/disk/storage.cpp
+++ b/src/storage/v2/disk/storage.cpp
@@ -1762,15 +1762,14 @@ utils::BasicResult<StorageManipulationError, void> DiskStorage::DiskAccessor::Co
   }
   auto commitStatus = transaction_.disk_transaction_->Commit();
   if (!commitStatus.ok()) {
-    transaction_.disk_transaction_->Rollback();
-  }
-  delete transaction_.disk_transaction_;
-  transaction_.disk_transaction_ = nullptr;
-
-  if (!commitStatus.ok()) {
+    Abort();
     spdlog::error("rocksdb: Commit failed with status {}", commitStatus.ToString());
     return StorageManipulationError{SerializationError{}};
   }
+
+  delete transaction_.disk_transaction_;
+  transaction_.disk_transaction_ = nullptr;
+
   spdlog::trace("rocksdb: Commit successful");
 
   is_transaction_active_ = false;
diff --git a/tests/e2e/CMakeLists.txt b/tests/e2e/CMakeLists.txt
index 1876074ee..8878c880e 100644
--- a/tests/e2e/CMakeLists.txt
+++ b/tests/e2e/CMakeLists.txt
@@ -77,6 +77,7 @@ add_subdirectory(garbage_collection)
 add_subdirectory(query_planning)
 add_subdirectory(awesome_functions)
 add_subdirectory(high_availability)
+add_subdirectory(concurrent_write)
 
 add_subdirectory(replication_experimental)
 
diff --git a/tests/e2e/concurrent_write/CMakeLists.txt b/tests/e2e/concurrent_write/CMakeLists.txt
new file mode 100644
index 000000000..1546eedee
--- /dev/null
+++ b/tests/e2e/concurrent_write/CMakeLists.txt
@@ -0,0 +1,7 @@
+function(copy_concurrent_write_e2e_python_files FILE_NAME)
+    copy_e2e_python_files(concurrent_write ${FILE_NAME})
+endfunction()
+
+copy_concurrent_write_e2e_python_files(write.py)
+
+copy_e2e_files(concurrent_write workloads.yaml)
diff --git a/tests/e2e/concurrent_write/client.py b/tests/e2e/concurrent_write/client.py
new file mode 100644
index 000000000..4349cfcc8
--- /dev/null
+++ b/tests/e2e/concurrent_write/client.py
@@ -0,0 +1,33 @@
+import multiprocessing
+
+import mgclient
+import pytest
+
+
+def inner(query, number_of_executions):
+    connection = mgclient.connect(host="127.0.0.1", port=7687)
+    connection.autocommit = False
+    cursor = connection.cursor()
+    for _ in range(number_of_executions):
+        cursor.execute(query)
+    cursor.fetchall()
+
+
+class MemgraphClient:
+    def __init__(self):
+        self.query_list = []
+
+    def initialize_to_execute(self, query: str, number_of_executions):
+        self.query_list.append((query, number_of_executions))
+
+    def execute_queries(self):
+        num_processes = len(self.query_list)
+        with multiprocessing.Pool(processes=num_processes) as pool:
+            pool.starmap(inner, self.query_list)
+
+        return True
+
+
+@pytest.fixture
+def client() -> MemgraphClient:
+    return MemgraphClient()
diff --git a/tests/e2e/concurrent_write/workloads.yaml b/tests/e2e/concurrent_write/workloads.yaml
new file mode 100644
index 000000000..18dfac5d0
--- /dev/null
+++ b/tests/e2e/concurrent_write/workloads.yaml
@@ -0,0 +1,33 @@
+args: &args
+ - "--bolt-port"
+ - "7687"
+ - "--log-level"
+ - "TRACE"
+
+in_memory_cluster: &in_memory_cluster
+  cluster:
+    main:
+      args: *args
+      log_file: "concurrent-write-e2e.log"
+      setup_queries: []
+      validation_queries: []
+
+disk_cluster: &disk_cluster
+  cluster:
+    main:
+      args: *args
+      log_file: "concurrent-write-e2e.log"
+      setup_queries: ["STORAGE MODE ON_DISK_TRANSACTIONAL"]
+      validation_queries: []
+
+workloads:
+  - name: "Concurrent write"
+    binary: "tests/e2e/pytest_runner.sh"
+    proc: "tests/e2e/concurrent_write/test_query_modules/"
+    args: ["concurrent_write/write.py"]
+    <<: *in_memory_cluster
+  - name: "Disk concurrent write"
+    binary: "tests/e2e/pytest_runner.sh"
+    proc: "tests/e2e/concurrent_write/test_query_modules/"
+    args: ["concurrent_write/write.py"]
+    <<: *disk_cluster
diff --git a/tests/e2e/concurrent_write/write.py b/tests/e2e/concurrent_write/write.py
new file mode 100644
index 000000000..87528bbfc
--- /dev/null
+++ b/tests/e2e/concurrent_write/write.py
@@ -0,0 +1,73 @@
+import sys
+import threading
+import time
+import typing
+
+import mgclient
+import pytest
+
+
+def execute_and_fetch_all(cursor: mgclient.Cursor, query: str, params: dict = {}) -> typing.List[tuple]:
+    cursor.execute(query, params)
+    return cursor.fetchall()
+
+
+commit_success_lock = threading.Lock()
+commit_fail_lock = threading.Lock()
+
+
+def client_success():
+    commit_fail_lock.acquire()
+    time.sleep(0.1)
+    connection = mgclient.connect(host="localhost", port=7687)
+    connection.autocommit = False
+
+    cursor = connection.cursor()
+
+    execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;")
+    execute_and_fetch_all(cursor, "CREATE (:N1), (:N2);")
+    connection.commit()
+
+    execute_and_fetch_all(cursor, "MATCH (n1:N1) DELETE n1;")
+    commit_success_lock.acquire()
+    commit_fail_lock.release()
+    connection.commit()
+    commit_success_lock.release()
+
+
+def client_fail():
+    try:
+        commit_success_lock.acquire()
+        connection = mgclient.connect(host="localhost", port=7687)
+        connection.autocommit = False
+        cursor = connection.cursor()
+
+        execute_and_fetch_all(cursor, "MATCH (n1:N1), (n2:N2) CREATE (n1)-[:R]->(n2);")
+        commit_success_lock.release()
+        commit_fail_lock.acquire()
+        connection.commit()
+    except mgclient.DatabaseError:
+        commit_fail_lock.release()
+
+
+def test_concurrent_write():
+    t1 = threading.Thread(target=client_success)
+    t2 = threading.Thread(target=client_fail)
+
+    t1.start()
+    t2.start()
+
+    t1.join()
+    t2.join()
+
+    connection = mgclient.connect(host="localhost", port=7687)
+    connection.autocommit = True
+    cursor = connection.cursor()
+    assert execute_and_fetch_all(cursor, "MATCH (n:N1) RETURN inDegree(n);") == []
+    assert execute_and_fetch_all(cursor, "MATCH (n:N1) RETURN outDegree(n);") == []
+    assert execute_and_fetch_all(cursor, "MATCH (n:N2) RETURN inDegree(n);")[0][0] == 0
+    assert execute_and_fetch_all(cursor, "MATCH (n:N2) RETURN outDegree(n);")[0][0] == 0
+
+
+if __name__ == "__main__":
+    sys.exit(pytest.main([__file__, "-rA"]))