From e7cde4b4ef058066a984c6c3c5aa2e6cb6a3aae2 Mon Sep 17 00:00:00 2001
From: Matej Ferencevic <matej.ferencevic@memgraph.io>
Date: Thu, 23 Aug 2018 12:00:48 +0200
Subject: [PATCH] Make analyze_rpc_calls compatible with Cap'n Proto

Reviewers: teon.banek, buda

Reviewed By: teon.banek

Subscribers: pullbot

Differential Revision: https://phabricator.memgraph.io/D1561
---
 tools/analyze_rpc_calls | 170 +++++++++++++++++++++++++---------------
 1 file changed, 107 insertions(+), 63 deletions(-)

diff --git a/tools/analyze_rpc_calls b/tools/analyze_rpc_calls
index 015ca7be2..a6008bf90 100755
--- a/tools/analyze_rpc_calls
+++ b/tools/analyze_rpc_calls
@@ -4,80 +4,93 @@ import collections
 import dpkt
 import json
 import operator
+import os
 import socket
+import subprocess
 import struct
 import tabulate
 
+SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
+PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, ".."))
+
+
 # helpers
 
 def format_endpoint(addr, port):
     return "{}:{}".format(socket.inet_ntoa(addr), port)
 
 
+def parse_capnp_header(fname):
+    """Map Cap'n Proto type ids (int) to struct names declared in fname."""
+    ret = {}
+    last_struct = ""
+    with open(fname) as f:
+        for row in map(str.strip, f):
+            if row.startswith("struct") and \
+                    not row.startswith("struct _capnpPrivate"):
+                last_struct = row.split()[1]
+            if row.startswith("CAPNP_DECLARE_STRUCT_HEADER"):
+                # First macro argument: the type id as bare hex digits.
+                hex_id = row.split("(")[1].split(",")[0]
+                ret[int(hex_id, 16)] = last_struct
+    return ret
+
+
+def parse_all_capnp_headers(dirname):
+    """Merge the id -> name maps of all *.capnp.h files found under dirname."""
+    found = subprocess.run(["find", dirname, "-name", "*.capnp.h"],
+                           stdout=subprocess.PIPE)
+    found.check_returncode()
+    ids = {}
+    for header in filter(None, found.stdout.decode("utf-8").split("\n")):
+        ids.update(parse_capnp_header(header))
+    return ids
+
+
+MESSAGES = parse_all_capnp_headers(os.path.join(PROJECT_DIR, "src"))
+
+
 class Connection:
-    # uint32_t channel_size
-    HEADER_FORMAT = "I"
-
-    # uint32_t message_number, uint32_t message_size
-    MESSAGE_FORMAT = "II"
-
-    # 8: boost archive string size
-    # 22: boost archive fixed string "serialization::archive"
-    # 17: boost archive magic bytes
-    BOOST_OFFSET = 8 + 22 + 17
+    # uint32_t message_size
+    SIZE_FORMAT = "I"
+    SIZE_LEN = struct.calcsize(SIZE_FORMAT)
 
     def __init__(self):
-        self._handshake_done = False
-        self._channel = ""
-
         self._data = bytes()
         self._message = bytes()
         self._ts = []
 
-        self._header_format_len = struct.calcsize(self.HEADER_FORMAT)
-        self._message_format_len = struct.calcsize(self.MESSAGE_FORMAT)
-
         self._last = None
-        self._stats = collections.defaultdict(lambda:
-                {"duration": [], "size": []})
-
-    def _extract_channel(self):
-        if len(self._data) < self._header_format_len:
-            return False
-        chan_len = struct.unpack_from(self.HEADER_FORMAT, self._data)[0]
-        if len(self._data) < self._header_format_len + chan_len:
-            return False
-        self._channel = self._data[self._header_format_len:].decode("utf-8")
-        self._data = bytes()
-        return True
+        self._stats = collections.defaultdict(lambda: {"duration": [],
+                                                       "size": []})
+        self._requests = []
 
     def _extract_message(self):
-        if len(self._data) < self._message_format_len:
+        if len(self._data) < self.SIZE_LEN:
             return False
-        msg_num, msg_len = struct.unpack_from("IH", self._data)
-        if len(self._data) < self._message_format_len + msg_len:
+        msg_len = struct.unpack_from(self.SIZE_FORMAT, self._data)[0]
+        if len(self._data) < self.SIZE_LEN + msg_len:
             return False
-        self._message = self._data[self._message_format_len:]
+        self._message = self._data[self.SIZE_LEN:self.SIZE_LEN + msg_len]
         self._data = bytes()
         return True
 
-    def add_data(self, data, ts):
+    def add_data(self, data, direction, ts):
         self._data += data
         self._ts.append(ts)
 
-        if not self._handshake_done:
-            if not self._extract_channel():
-                return
-            self._handshake_done = True
-            self._ts = []
-
         if not self._extract_message():
             return
 
-        message_type_size = struct.unpack_from("Q", self._message,
-                self.BOOST_OFFSET)[0]
-        message_type = struct.unpack_from("{}s".format(message_type_size),
-                self._message, self.BOOST_OFFSET + 8)[0].decode("utf-8")
+        message_id = struct.unpack("<Q", self._message[16:24])[0]
+        message_type = MESSAGES[message_id]
+
+        if direction == "to":
+            self._requests.append((self._ts[-1], direction, message_type,
+                                   len(self._message)))
+        else:
+            self._requests.append((self._ts[0], direction, message_type,
+                                   len(self._message)))
 
         if self._last is None:
             self._last = (message_type, self._ts[0])
@@ -94,13 +107,16 @@ class Connection:
     def get_stats(self):
         return self._stats
 
+    def get_requests(self):
+        return self._requests
+
 
 class Server:
     def __init__(self):
         self._conns = collections.defaultdict(lambda: Connection())
 
-    def add_data(self, addr, data, ts):
-        self._conns[addr].add_data(data, ts)
+    def add_data(self, addr, data, direction, ts):
+        self._conns[addr].add_data(data, direction, ts)
 
     def print_stats(self, machine_names, title, sort_by):
         stats = collections.defaultdict(lambda: collections.defaultdict(
@@ -114,10 +130,10 @@ class Server:
 
         table = []
         headers = ["RPC ({})".format(title), "Client", "Count", "Tmin (ms)",
-                "Tavg (ms)", "Tmax (ms)", "Ttot (s)", "Smin (B)",
-                "Savg (B)", "Smax (B)", "Stot (kiB)"]
+                   "Tavg (ms)", "Tmax (ms)", "Ttot (s)", "Smin (B)",
+                   "Savg (B)", "Smax (B)", "Stot (kiB)"]
         sort_keys = ["rpc", "client", "count", "tmin", "tavg", "tmax", "ttot",
-                "smin", "savg", "smax", "stot"]
+                     "smin", "savg", "smax", "stot"]
         for client in sorted(stats.keys()):
             rpcs = stats[client]
             for rpc, connstats in rpcs.items():
@@ -126,28 +142,43 @@ class Server:
                 durs_sum = sum(durs)
                 sizes_sum = sum(sizes)
                 table.append(["{} / {}".format(*rpc), machine_names[client],
-                        len(durs), min(durs) * 1000,
-                        durs_sum / len(durs) * 1000,
-                        max(durs) * 1000, durs_sum, min(sizes),
-                        int(sizes_sum / len(sizes)), max(sizes),
-                        sizes_sum / 1024])
+                              len(durs), min(durs) * 1000,
+                              durs_sum / len(durs) * 1000,
+                              max(durs) * 1000, durs_sum, min(sizes),
+                              int(sizes_sum / len(sizes)), max(sizes),
+                              sizes_sum / 1024])
         for sort_field in sort_by.split(","):
             reverse = True if sort_field.endswith("-") else False
             table.sort(key=operator.itemgetter(sort_keys.index(
                     sort_field.rstrip("+-"))), reverse=reverse)
         print(tabulate.tabulate(table, headers=headers, tablefmt="psql",
-                floatfmt=".3f"))
+                                floatfmt=".3f"))
+
+    def get_requests(self, server_name, machine_names):
+        """Return (ts, from, to, message, size) tuples for all connections."""
+        rows = []
+        for addr, conn in self._conns.items():
+            client_name = machine_names[addr.split(":")[0]]
+            for ts, direction, message, size in conn.get_requests():
+                pair = (server_name, client_name)
+                if direction != "from":
+                    pair = pair[::-1]
+                rows.append((ts, *pair, message, size))
+        return rows
 
 
 # process logic
 
 parser = argparse.ArgumentParser(description="Generate RPC statistics from "
-        "network traffic capture.")
+                                 "network traffic capture.")
 parser.add_argument("--sort-by", default="tavg+,count-,client+",
-        help="comma separated list of fields which should be used to sort "
-        "the data; each field can be suffixed with + or - to indicate "
-        "ascending or descending order; available fields: rpc, client, "
-        "count, min, avg, max, total")
+                    help="comma separated list of fields which should be used "
+                    "to sort the data; each field can be suffixed with + or - "
+                    "for ascending or descending order; fields: rpc, client, "
+                    "count, tmin, tavg, tmax, ttot, smin, savg, smax, stot")
+parser.add_argument("--no-aggregate", action="store_true",
+                    help="don't aggregate the results, instead display the "
+                    "individual RPC calls as they occurred")
 parser.add_argument("capfile", help="network traffic capture file")
 parser.add_argument("conffile", help="cluster config file")
 args = parser.parse_args()
@@ -185,10 +216,23 @@ for ts, pkt in dpkt.pcap.Reader(open(args.capfile, "rb")):
 
     server = dst if dst in server_addresses else src
     client = dst if dst not in server_addresses else src
+    direction = "to" if dst in server_addresses else "from"
 
-    servers[server].add_data(client, tcp.data, ts)
+    servers[server].add_data(client, tcp.data, direction, ts)
 
+requests = []
 for server in sorted(servers.keys()):
-    servers[server].print_stats(machine_names=machine_names,
-            title=machine_names[server.split(":")[0]],
-            sort_by=args.sort_by)
+    server_name = machine_names[server.split(":")[0]]
+    if args.no_aggregate:
+        requests.extend(servers[server].get_requests(server_name,
+                                                     machine_names))
+    else:
+        servers[server].print_stats(machine_names=machine_names,
+                                    title=server_name,
+                                    sort_by=args.sort_by)
+
+if args.no_aggregate:
+    requests.sort()
+    headers = ["timestamp", "from", "to", "request", "size"]
+    print(tabulate.tabulate(requests, headers=headers, tablefmt="psql",
+                            floatfmt=".6f"))