# Copyright 2023 Memgraph Ltd.
#
# Use of this software is governed by the Business Source License
# included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
# License, and you may not use this file except in compliance with the Business Source License.
#
# As of the Change Date specified in that file, in accordance with
# the Business Source License, use of this software will be governed
# by the Apache License, Version 2.0, included in the file
# licenses/APL.txt.

"""
Tests here work by executing the procedure under test twice.
The first time, the transaction is committed to confirm that the procedure
executes and doesn't crash for unrelated reasons;
the second time, the transaction is rolled back and the test checks that
the graph is left unchanged.
"""

import sys

import pytest
from common import execute_and_fetch_all


def test_change_from_rollback(connection):
    """Rolling back ``transaction_rollback.set_from`` must not change the relationship's start node.

    The procedure is first run and committed (Node1->Node2 becomes Node3->Node2)
    to show that it works at all. It is then run again and rolled back, and a
    rolled-back ``DETACH DELETE`` is checked the same way.
    """
    cursor = connection.cursor()

    def compare(from_label: str, to_label: str):
        """Assert the graph holds exactly one relationship, from a ``from_label`` node to a ``to_label`` node."""
        result = list(execute_and_fetch_all(cursor, "MATCH (n)-[r]->(m) RETURN n, r, m"))
        assert len(result) == 1, result
        node_from, _, node_to = result[0]
        # Compare whole label sets: picking element [0] of a set is order-dependent
        # and would silently accept nodes carrying extra labels.
        assert set(node_from.labels) == {from_label}
        assert set(node_to.labels) == {to_label}

    execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;")
    execute_and_fetch_all(cursor, "CREATE (n:Node1) CREATE (m:Node2) CREATE (k:Node3) CREATE (n)-[:Relationship]->(m);")
    connection.commit()

    # Committed run: move the start of the relationship from Node1 to Node3.
    execute_and_fetch_all(
        cursor,
        "MATCH (n:Node1)-[r:Relationship]->(m:Node2) MATCH (k:Node3) CALL transaction_rollback.set_from(r, k);",
    )
    connection.commit()
    compare("Node3", "Node2")

    # Rolled-back run: moving the start back to Node1 must not persist.
    execute_and_fetch_all(
        cursor,
        "MATCH (n:Node3)-[r:Relationship]->(m:Node2) MATCH (k:Node1) CALL transaction_rollback.set_from(r, k);",
    )
    connection.rollback()
    compare("Node3", "Node2")

    # A rolled-back delete of the whole graph must not persist either.
    execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;")
    connection.rollback()
    compare("Node3", "Node2")


def test_change_to_rollback(connection):
    """Rolling back ``transaction_rollback.set_to`` must not change the relationship's end node.

    The procedure is first run and committed (Node1->Node2 becomes Node1->Node3)
    to show that it works at all. It is then run again and rolled back, and a
    rolled-back ``DETACH DELETE`` is checked the same way.
    """
    cursor = connection.cursor()

    def compare(from_label: str, to_label: str):
        """Assert the graph holds exactly one relationship, from a ``from_label`` node to a ``to_label`` node."""
        result = list(execute_and_fetch_all(cursor, "MATCH (n)-[r]->(m) RETURN n, r, m"))
        assert len(result) == 1, result
        node_from, _, node_to = result[0]
        # Compare whole label sets: picking element [0] of a set is order-dependent
        # and would silently accept nodes carrying extra labels.
        assert set(node_from.labels) == {from_label}
        assert set(node_to.labels) == {to_label}

    execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;")
    execute_and_fetch_all(cursor, "CREATE (n:Node1) CREATE (m:Node2) CREATE (k:Node3) CREATE (n)-[:Relationship]->(m);")
    connection.commit()

    # Committed run: move the end of the relationship from Node2 to Node3.
    execute_and_fetch_all(
        cursor, "MATCH (n:Node1)-[r:Relationship]->(m:Node2) MATCH (k:Node3) CALL transaction_rollback.set_to(r, k);"
    )
    connection.commit()
    compare("Node1", "Node3")

    # Rolled-back run: moving the end back to Node2 must not persist.
    execute_and_fetch_all(
        cursor, "MATCH (n:Node1)-[r:Relationship]->(m:Node3) MATCH (k:Node2) CALL transaction_rollback.set_to(r, k);"
    )
    connection.rollback()
    compare("Node1", "Node3")

    # A rolled-back delete of the whole graph must not persist either.
    execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;")
    connection.rollback()
    compare("Node1", "Node3")


def test_change_rel_type_rollback(connection):
    """Rolling back ``transaction_rollback.change_type`` must not change the relationship type.

    The procedure is first run and committed (type ``Relationship`` becomes
    ``Rel``) to show that it works at all. It is then run again and rolled back,
    and a rolled-back ``DETACH DELETE`` is checked the same way.
    """
    cursor = connection.cursor()

    def compare(rel_type: str):
        """Assert the graph holds exactly one relationship and that its type is ``rel_type``."""
        result = list(execute_and_fetch_all(cursor, "MATCH (n)-[r]->(m) RETURN r"))
        assert len(result) == 1, result
        rel = result[0][0]
        assert rel.type == rel_type

    execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;")
    execute_and_fetch_all(cursor, "CREATE (n:Node1) CREATE (m:Node2) CREATE (n)-[:Relationship]->(m);")
    connection.commit()

    # Committed run: rename the relationship type to 'Rel'.
    execute_and_fetch_all(
        cursor, "MATCH (n:Node1)-[r:Relationship]->(m:Node2) CALL transaction_rollback.change_type(r, 'Rel');"
    )
    connection.commit()
    compare("Rel")

    # Rolled-back run: renaming back to 'Relationship' must not persist.
    execute_and_fetch_all(
        cursor, "MATCH (n:Node1)-[r:Rel]->(m:Node2) CALL transaction_rollback.change_type(r, 'Relationship');"
    )
    connection.rollback()
    compare("Rel")

    # A rolled-back delete of the whole graph must not persist either.
    execute_and_fetch_all(cursor, "MATCH (n) DETACH DELETE n;")
    connection.rollback()
    compare("Rel")


if __name__ == "__main__":
    # Run this module's tests directly; "-rA" prints a short summary for every outcome.
    exit_code = pytest.main([__file__, "-rA"])
    sys.exit(exit_code)