#!/usr/bin/env python3
import argparse
import collections
import dpkt
import json
import operator
import os
import re
import socket
import subprocess
import struct
import sys
import tabulate

# Directory containing this script (symlinks resolved).
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
# Project root, taken to be one level above the script directory.
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, ".."))

# Matches C++ definitions shaped like
#   const utils::TypeInfo Foo::kType{0x1234ULL, "Foo"}
# capturing (1) the qualified variable name, (2) the id literal without its
# ULL suffix (parsed as hex by parse_source_file) and (3) the quoted class
# name. re.VERBOSE makes whitespace and the trailing "#" notes inside the
# pattern insignificant.
TYPE_INFO_PATTERN = re.compile(r"""
 const\s+utils::TypeInfo\s+
 (\S+)\s*                     # <class-name>::kType
 {\s*
 (\S+)ULL\s*,\s*              # <id-hex>
 "(\S+)"\s*                   # "<class-name>"
 }
 """, re.VERBOSE)

# helpers

def format_endpoint(addr, port):
    """Render a packed 4-byte IPv4 address and a port as "a.b.c.d:port"."""
    host = socket.inet_ntoa(addr)
    return "{}:{}".format(host, port)


def parse_source_file(fname):
    """Extract TypeInfo definitions from one C++ source file.

    Returns a dict mapping each numeric type id (parsed from the hex literal
    captured by TYPE_INFO_PATTERN) to the class name stored next to it.
    """
    # Decode explicitly: the default encoding is locale-dependent. The
    # TypeInfo tokens we look for are expected to be ASCII, so replacing
    # undecodable bytes elsewhere in the file is preferable to aborting.
    with open(fname, encoding="utf-8", errors="replace") as f:
        contents = f.read()

    ret = {}
    for match in TYPE_INFO_PATTERN.finditer(contents):
        _, id_hex, class_name = match.groups()
        ret[int(id_hex, 16)] = class_name
    return ret


def parse_all_source_files(dirname):
    """Merge the TypeInfo ids of every ``*.lcp.cpp`` file below *dirname*.

    Files are located with the external ``find`` utility. If it exits with a
    non-zero status, subprocess.CalledProcessError is raised. When several
    files define the same id, the file listed later by ``find`` wins.
    """
    listing = subprocess.run(["find", dirname, "-name", "*.lcp.cpp"],
                             stdout=subprocess.PIPE, check=True)
    paths = [path for path in listing.stdout.decode("utf-8").split("\n")
             if path]

    ids = {}
    for path in paths:
        ids.update(parse_source_file(path))
    return ids


# Type id (int) -> C++ class name, built at import time by scanning every
# *.lcp.cpp file under <project>/src. Used to recognise captured messages.
MESSAGES = parse_all_source_files(os.path.join(PROJECT_DIR, "src"))


class Connection:
    """Reassemble one client<->server TCP stream into RPC messages.

    Frames are a length prefix (``SIZE_FORMAT``) followed by that many
    payload bytes. Every identified message is logged (``get_requests``),
    and each request ("to" the server) is paired with the next response
    ("from" the server) to build per-RPC duration/size statistics
    (``get_stats``).
    """

    # uint32_t message_size
    SIZE_FORMAT = "I"
    SIZE_LEN = struct.calcsize(SIZE_FORMAT)

    def __init__(self):
        # Bytes of the last complete frame; used to spot TCP retransmissions.
        self._previous = bytes()
        # Received bytes that do not yet form a complete frame.
        self._data = bytes()
        # Payload (length prefix stripped) of the last extracted frame.
        self._message = bytes()
        # Capture timestamps of the packets carrying the current frame.
        self._ts = []

        # Request awaiting its response: (message type, first ts, size).
        self._last = None
        # (request type, response type) -> {"duration": [...], "size": [...]}
        self._stats = collections.defaultdict(lambda: {"duration": [],
                                                       "size": []})
        # (ts, direction, message type, size) for every identified message.
        self._requests = []

    def _extract_message(self):
        """Move one complete frame's payload from the buffer to _message.

        Returns True if a frame was extracted, False if more data is needed.
        Any bytes after the frame stay buffered as the start of the next one.
        """
        if len(self._data) < self.SIZE_LEN:
            return False
        msg_len = struct.unpack_from(self.SIZE_FORMAT, self._data)[0]
        frame_end = self.SIZE_LEN + msg_len
        if len(self._data) < frame_end:
            return False
        # Take exactly msg_len bytes so trailing data is neither counted in
        # this message's size nor thrown away.
        self._message = self._data[self.SIZE_LEN:frame_end]
        self._previous = self._data[:frame_end]
        self._data = self._data[frame_end:]
        return True

    def _identify_message(self):
        """Return the class name of the current message, or None.

        A little-endian uint64 is read at payload offsets 16, 24, 32 and 40
        in turn. The first value found in MESSAGES is taken as the type id.
        """
        for i in range(2, 6):
            if len(self._message) < (i + 1) * 8:
                break  # later offsets are out of range too
            message_id = struct.unpack_from("<Q", self._message, i * 8)[0]
            if message_id in MESSAGES:
                return MESSAGES[message_id]
        return None

    def _handle_message(self, direction):
        """Record the current message and pair it into the statistics."""
        message_type = self._identify_message()
        if message_type is None:
            print("Got a message that I can't identify as any known "
                  "RPC request/response!", file=sys.stderr)
            self._last = None
            return

        size = len(self._message)
        # Requests are timestamped by their last packet, responses by their
        # first one.
        log_ts = self._ts[-1] if direction == "to" else self._ts[0]
        self._requests.append((log_ts, direction, message_type, size))

        if direction == "to":
            # A new request replaces one whose response was never seen.
            self._last = (message_type, self._ts[0], size)
        elif self._last is not None:
            req_type, req_ts, req_size = self._last
            key = (req_type, message_type)
            self._stats[key]["duration"].append(self._ts[-1] - req_ts)
            self._stats[key]["size"].append(req_size + size)
            self._last = None
        # A response with no pending request (e.g. the capture started
        # mid-RPC) is logged above but not paired, so that requests and
        # responses are never paired the wrong way round.

    def add_data(self, data, direction, ts):
        """Feed one TCP payload captured at time *ts*.

        *direction* is "to" for client->server and "from" for server->client
        traffic. Every frame completed by this payload is processed.
        """
        if not data:
            return
        # Heuristic: a payload equal to the tail of what was already seen
        # is treated as a TCP retransmission and ignored.
        if self._previous[-len(data):] == data \
                or self._data[-len(data):] == data:
            print("Retransmission detected!", file=sys.stderr)
            return

        self._data += data
        self._ts.append(ts)

        # NOTE(review): a garbage length prefix (e.g. a capture starting
        # mid-frame) stalls this stream; there is no resynchronisation.
        while self._extract_message():
            self._handle_message(direction)
            # Leftover bytes start the next frame, which began in this packet.
            self._ts = [ts] if self._data else []

    def get_stats(self):
        """Return the (request, response) -> duration/size lists mapping."""
        return self._stats

    def get_requests(self):
        """Return (ts, direction, message type, size) for each message."""
        return self._requests


class Server:
    """Collect and report on the traffic of every client connection to one
    server endpoint."""

    def __init__(self):
        # Keyed by client "ip:port"; a Connection is created on first use.
        self._conns = collections.defaultdict(Connection)

    def add_data(self, addr, data, direction, ts):
        """Feed a TCP payload seen on the connection with client *addr*."""
        self._conns[addr].add_data(data, direction, ts)

    def print_stats(self, machine_names, title, sort_by):
        """Print a table of per-client, per-RPC duration and size statistics.

        machine_names maps client IP -> display name; IPs missing from it are
        shown verbatim. title is placed in the first column header. sort_by
        is a comma-separated list of column keys (see ``sort_keys``), each
        optionally suffixed with "+" (ascending) or "-" (descending). The
        sorts are applied left to right with the stable ``list.sort``, so
        the last field listed is the primary key.

        Raises ValueError for an unknown sort field.
        """
        stats = collections.defaultdict(lambda: collections.defaultdict(
                lambda: {"duration": [], "size": []}))

        # Merge all connections originating from the same client IP.
        for addr, conn in self._conns.items():
            ip = addr.split(":")[0]
            for rpc, connstats in conn.get_stats().items():
                stats[ip][rpc]["duration"] += connstats["duration"]
                stats[ip][rpc]["size"] += connstats["size"]

        table = []
        headers = ["RPC ({})".format(title), "Client", "Count", "Tmin (ms)",
                   "Tavg (ms)", "Tmax (ms)", "Ttot (s)", "Smin (B)",
                   "Savg (B)", "Smax (B)", "Stot (kiB)"]
        sort_keys = ["rpc", "client", "count", "tmin", "tavg", "tmax", "ttot",
                     "smin", "savg", "smax", "stot"]
        for client in sorted(stats):
            for rpc, connstats in stats[client].items():
                # Entries only exist once a pair was recorded, so the lists
                # are never empty here.
                durs = connstats["duration"]
                sizes = connstats["size"]
                durs_sum = sum(durs)
                sizes_sum = sum(sizes)
                table.append(["{} / {}".format(*rpc),
                              machine_names.get(client, client),
                              len(durs), min(durs) * 1000,
                              durs_sum / len(durs) * 1000,
                              max(durs) * 1000, durs_sum, min(sizes),
                              int(sizes_sum / len(sizes)), max(sizes),
                              sizes_sum / 1024])

        for sort_field in sort_by.split(","):
            field = sort_field.rstrip("+-")
            if not field:
                continue  # tolerate stray commas / bare "+" or "-"
            if field not in sort_keys:
                raise ValueError("unknown sort field {!r}; expected one of: "
                                 "{}".format(field, ", ".join(sort_keys)))
            table.sort(key=operator.itemgetter(sort_keys.index(field)),
                       reverse=sort_field.endswith("-"))
        print(tabulate.tabulate(table, headers=headers, tablefmt="psql",
                                floatfmt=".3f"))

    def get_requests(self, server_name, machine_names):
        """Return (ts, from, to, message type, size) for every message.

        Client IPs missing from machine_names are shown verbatim.
        """
        ret = []
        for addr, conn in self._conns.items():
            client_ip = addr.split(":")[0]
            client_name = machine_names.get(client_ip, client_ip)
            for ts, direction, message, size in conn.get_requests():
                if direction == "from":
                    name_from, name_to = server_name, client_name
                else:
                    name_from, name_to = client_name, server_name
                ret.append((ts, name_from, name_to, message, size))
        return ret


# process logic

parser = argparse.ArgumentParser(description="Generate RPC statistics from "
                                 "network traffic capture.")
parser.add_argument("--sort-by", default="tavg+,count-,client+",
                    help="comma separated list of fields which should be used "
                    "to sort the data; each field can be suffixed with + or - "
                    "to indicate ascending or descending order; the last "
                    "field listed is the primary sort key; available "
                    "fields: rpc, client, count, tmin, tavg, tmax, ttot, "
                    "smin, savg, smax, stot")
parser.add_argument("--no-aggregate", action="store_true",
                    help="don't aggregate the results, instead display the "
                    "individual RPC calls as they occurred")
parser.add_argument("capfile", help="network traffic capture file")
parser.add_argument("conffile", help="cluster config file")
args = parser.parse_args()

with open(args.conffile) as conffile:
    config = json.load(conffile)

# Machine IP -> display name; workers are numbered in config order.
last_worker = 0
machine_names = {}
# A set, because membership is tested several times per captured packet.
server_addresses = set()
for machine in config["workload_machines"]:
    name = machine["type"]
    if name == "worker":
        last_worker += 1
        name += str(last_worker)
    machine_names["{address}".format(**machine)] = name
    server_addresses.add("{address}:{port}".format(**machine))

servers = collections.defaultdict(Server)

with open(args.capfile, "rb") as capfile:
    for ts, pkt in dpkt.pcap.Reader(capfile):
        eth = dpkt.ethernet.Ethernet(pkt)
        if eth.type != dpkt.ethernet.ETH_TYPE_IP:
            continue

        ip = eth.data
        # dpkt may leave a payload it could not decode as raw bytes; skip
        # such packets instead of failing on a missing attribute.
        if not isinstance(ip, dpkt.ip.IP) or ip.p != dpkt.ip.IP_PROTO_TCP:
            continue

        tcp = ip.data
        if not isinstance(tcp, dpkt.tcp.TCP) or len(tcp.data) == 0:
            continue

        src = format_endpoint(ip.src, tcp.sport)
        dst = format_endpoint(ip.dst, tcp.dport)
        to_server = dst in server_addresses
        if not to_server and src not in server_addresses:
            continue

        # Attribute the packet to the server endpoint; the peer is the client.
        if to_server:
            server, client, direction = dst, src, "to"
        else:
            server, client, direction = src, dst, "from"
        servers[server].add_data(client, tcp.data, direction, ts)

requests = []
for server in sorted(servers):
    server_name = machine_names[server.split(":")[0]]
    if args.no_aggregate:
        requests.extend(servers[server].get_requests(server_name,
                                                     machine_names))
    else:
        servers[server].print_stats(machine_names=machine_names,
                                    title=server_name,
                                    sort_by=args.sort_by)

if args.no_aggregate:
    requests.sort()
    headers = ["timestamp", "from", "to", "request", "size"]
    print(tabulate.tabulate(requests, headers=headers, tablefmt="psql",
                            floatfmt=".6f"))