#!/usr/bin/python3 -u
import argparse
import json
import os
import signal
import sys
import time
import itertools

from http.server import BaseHTTPRequestHandler, HTTPServer


def build_handler(storage, args):
    """Return a request-handler class that collects telemetry POSTs.

    Every accepted telemetry item is appended to *storage*. *args* is the
    parsed command line; its ``no_response_count`` and ``wrong_code_count``
    fields are decremented in place as those simulated failures are used up.
    Validation is done with ``assert``, so a malformed request raises
    AssertionError out of the handler.
    """
    required_keys = ("event", "run_id", "machine_id", "data", "timestamp")

    class Handler(BaseHTTPRequestHandler):
        # Only POST is part of the protocol under test; any other method
        # fails an assertion.
        def do_HEAD(self):
            assert False

        def do_GET(self):
            assert False

        def do_PUT(self):
            assert False

        def do_POST(self):
            redirecting = args.redirect and self.path == args.path
            if redirecting:
                # Temporary redirect (307) rather than 301 so the client
                # re-sends the body to the new location:
                # https://stackoverflow.com/questions/19070801/curl-loses-body-when-a-post-redirected-from-http-to-https
                self.send_response(307)
                self.send_header("Location", args.redirect_path)
                self.end_headers()
                return

            expected_headers = {
                "user-agent": "memgraph/telemetry",
                "accept": "application/json",
                "content-type": "application/json",
            }
            for header_name, header_value in expected_headers.items():
                assert self.headers[header_name] == header_value

            length = int(self.headers.get('content-length', 0))
            raw_body = self.rfile.read(length)
            payload = json.loads(raw_body.decode("utf-8"))

            if self.path != args.path and self.path != args.redirect_path:
                self.send_response(404)
                self.end_headers()
                return

            if args.no_response_count > 0:
                # Simulated failure: consume the request but never reply.
                args.no_response_count -= 1
                return

            if args.wrong_code_count > 0:
                # Simulated failure: reply with a server error.
                args.wrong_code_count -= 1
                self.send_response(500)
                self.end_headers()
                return

            assert type(payload) is list
            for entry in payload:
                assert type(entry) is dict
                for key in required_keys:
                    assert key in entry
                storage.append(entry)

            if args.hang:
                # Delay the reply; presumably to exercise client timeouts.
                time.sleep(20)

            self.send_response(200)
            self.end_headers()

    return Handler


class Server(HTTPServer):
    """HTTPServer variant used by this test harness.

    Any exception escaping a request handler (including the handler's
    assertion failures) terminates the whole process with exit status 1.
    shutdown() only raises the stop flag, so it can be called from the same
    thread that is running serve_forever() (e.g. from a signal handler).
    """

    def handle_error(self, request, client_address):
        """Print the default error report, then exit the process with status 1."""
        try:
            super().handle_error(request, client_address)
            # os._exit() skips interpreter cleanup, including flushing stdio
            # buffers. Flush explicitly so the traceback is not lost when the
            # script is run without -u (redirected stderr is block-buffered
            # on Python < 3.9).
            sys.stdout.flush()
            sys.stderr.flush()
        finally:
            # Exit even if reporting or flushing itself failed.
            os._exit(1)

    def shutdown(self):
        """Ask serve_forever() to stop, without waiting for it to finish.

        HACK: the parent implementation sets the shutdown flag and then
        waits for serve_forever() to complete. The server is not run in a
        separate thread, so waiting would never return. Only the flag is set
        here. The flag is a name-mangled private attribute of BaseServer,
        hence the explicit ``_BaseServer__`` prefix. Parent implementation:
        https://github.com/python/cpython/blob/3.5/Lib/socketserver.py#L241
        """
        self._BaseServer__shutdown_request = True


def item_sort_key(obj):
    """Return the key used to order received telemetry items.

    Plain dicts that carry a "timestamp" sort by that value. Anything else
    maps to -1, which places it ahead of any non-negative timestamp.
    """
    has_timestamp = type(obj) is dict and "timestamp" in obj
    return obj["timestamp"] if has_timestamp else -1


def verify_storage(storage, args):
    """Assert that the telemetry items of one run form the expected sequence.

    Every item must share the first item's run_id, and timestamps must be
    non-decreasing. The events must be "startup", then 0, 1, 2, ..., then
    "shutdown". A single-item run is only checked as a startup.
    The startup item must describe the machine. Every later item must
    report db vertices/edges equal to its position, resource usage, and an
    uptime within +/-4 of ``position * args.interval``. For the final item
    the target is ``args.duration`` instead, and the uptime is not checked
    when ``args.no_check_duration`` is set.

    Raises AssertionError on a mismatch, KeyError on a missing key, and
    IndexError if *storage* is empty.
    """
    run_id = storage[0]["run_id"]
    last_index = len(storage) - 1
    previous_ts = 0
    startup_keys = ("architecture", "cpu_count", "cpu_model", "kernel",
                    "memory", "os", "swap")

    for index, item in enumerate(storage):
        assert item["run_id"] == run_id

        current_ts = item["timestamp"]
        assert current_ts >= previous_ts
        previous_ts = current_ts

        event = item["event"]
        if index == 0:
            assert event == "startup"
        elif index == last_index:
            assert event == "shutdown"
        else:
            assert event == index - 1

        payload = item["data"]
        if index == 0:
            for key in startup_keys:
                assert key in payload
            continue

        db = payload["db"]
        assert db["vertices"] == index
        assert db["edges"] == index

        assert "resources" in payload
        assert "cpu" in payload["resources"]
        assert "memory" in payload["resources"]
        assert "uptime" in payload

        uptime = payload["uptime"]
        if index != last_index:
            expected = index * args.interval
        elif args.no_check_duration:
            # Comparing uptime with itself makes the range check a no-op.
            expected = uptime
        else:
            expected = args.duration
        # Tolerance of 4, in the same units as interval/duration.
        assert expected - 4 <= uptime <= expected + 4


def split_startups(storage):
    """Group consecutive items that share a run_id into one list per startup.

    *storage* should already be sorted so that each run's items are
    contiguous. An empty input yields an empty list.
    """
    return [
        list(group)
        for _, group in itertools.groupby(storage, key=lambda item: item["run_id"])
    ]


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--address", type=str, default="127.0.0.1")
    parser.add_argument("--port", type=int, default=9000)
    parser.add_argument("--path", type=str, default="/")
    parser.add_argument("--redirect", action="store_true")
    parser.add_argument("--no-response-count", type=int, default=0)
    parser.add_argument("--wrong-code-count", type=int, default=0)
    parser.add_argument("--no-check", action="store_true")
    parser.add_argument("--hang", action="store_true")
    parser.add_argument("--interval", type=int, default=1)
    parser.add_argument("--duration", type=int, default=10)
    parser.add_argument("--startups", type=int, default=1)
    parser.add_argument("--no-check-duration", action="store_true")
    args = parser.parse_args()
    # Telemetry is also accepted on <path>/redirect. With --redirect, POSTs
    # to <path> are bounced there.
    args.redirect_path = os.path.join(args.path, "redirect")

    storage = []
    handler = build_handler(storage, args)
    httpd = Server((args.address, args.port), handler)
    # SIGTERM only raises the server's stop flag; serve_forever() then returns.
    signal.signal(signal.SIGTERM, lambda signum, frame: httpd.shutdown())
    try:
        httpd.serve_forever()
    finally:
        httpd.server_close()

    if args.no_check:
        sys.exit(0)

    # Order the received data, then split it into individual startups.
    storage.sort(key=item_sort_key)
    startups = split_startups(storage)

    # Check that there were the correct number of startups. An empty
    # storage yields zero startups instead of crashing with IndexError.
    assert len(startups) == args.startups

    # Verify each startup.
    for startup in startups:
        verify_storage(startup, args)

    # The machine id must be identical for every run on the same machine.
    if storage:
        assert len({item["machine_id"] for item in storage}) == 1