#!/usr/bin/env python3
"""Generate RPC statistics from a network traffic capture.

The script scans the project's ``*.lcp.cpp`` sources for ``utils::TypeInfo``
definitions to learn the numeric ids of known RPC message types, then reads a
pcap capture, reassembles the size-prefixed messages exchanged with the
machines listed in the cluster config file and prints either aggregated
request/response statistics or the individual messages.
"""
import argparse
import collections
import dpkt
import json
import operator
import os
import re
import socket
import subprocess
import struct
import sys
import tabulate

SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, ".."))

# Matches definitions of the form
#   const utils::TypeInfo <class-name>::kType { <id-hex>ULL, "<class-name>" }
TYPE_INFO_PATTERN = re.compile(r"""
    const\s+utils::TypeInfo\s+
    (\S+)\s*                    # <class-name>::kType
    {\s*
      (\S+)ULL\s*,\s*           # <id-hex>
      "(\S+)"\s*                # "<class-name>"
    }
    """, re.VERBOSE)

# Names accepted by `--sort-by`, in the same order as the columns of the
# table built by `Server.print_stats` (the position is the column index).
SORT_KEYS = ["rpc", "client", "count", "tmin", "tavg", "tmax", "ttot",
             "smin", "savg", "smax", "stot"]


# helpers

def format_endpoint(addr, port):
    """Return ``"<dotted-quad>:<port>"`` for a packed IPv4 address and port."""
    return "{}:{}".format(socket.inet_ntoa(addr), port)


def parse_source_file(fname):
    """Return a ``{type id: class name}`` dict for one source file.

    Every match of `TYPE_INFO_PATTERN` in the file contributes one entry; the
    id is parsed as a hexadecimal integer.
    """
    ret = {}
    with open(fname) as f:
        for match in TYPE_INFO_PATTERN.finditer(f.read()):
            ret[int(match.group(2), 16)] = match.group(3)
    return ret


def parse_all_source_files(dirname):
    """Return the merged ``{type id: class name}`` dict for a source tree.

    The files are located with ``find <dirname> -name "*.lcp.cpp"``.

    Raises:
        subprocess.CalledProcessError: if ``find`` exits with a non-zero
            status.
    """
    ids = {}
    ret = subprocess.run(["find", dirname, "-name", "*.lcp.cpp"],
                         stdout=subprocess.PIPE)
    ret.check_returncode()
    # Blank entries (e.g. after the trailing newline) are dropped.
    paths = filter(None, ret.stdout.decode("utf-8").split("\n"))
    for path in paths:
        ids.update(parse_source_file(path))
    return ids


MESSAGES = parse_all_source_files(os.path.join(PROJECT_DIR, "src"))


class Connection:
    """Reassembles and pairs the messages seen on one client connection.

    Raw TCP payloads are accumulated until a complete size-prefixed message is
    available. Identified messages are recorded individually and, two at a
    time, as a (first message, second message) pair with duration and size.
    """

    # uint32_t message_size
    # No byte-order prefix, so `struct` uses the native byte order here.
    SIZE_FORMAT = "I"
    SIZE_LEN = struct.calcsize(SIZE_FORMAT)

    def __init__(self):
        # Raw bytes (size prefix included) of the last extracted message;
        # used by `add_data` to detect retransmissions.
        self._previous = bytes()
        # Bytes received so far for the message being reassembled.
        self._data = bytes()
        # Payload of the most recently extracted message.
        self._message = bytes()
        # Timestamps of the payloads accumulated in `_data`.
        self._ts = []

        # Pending first half of a pair: (message type, timestamp, size).
        self._last = None
        # (first type, second type) -> lists of durations and combined sizes.
        self._stats = collections.defaultdict(lambda: {"duration": [],
                                                        "size": []})
        # Every identified message: (timestamp, direction, type, size).
        self._requests = []

    def _extract_message(self):
        """Move a complete message from `_data` into `_message`.

        Returns:
            True if the buffer held the size prefix plus at least the
            announced number of bytes, False otherwise (buffer untouched).
        """
        if len(self._data) < self.SIZE_LEN:
            return False
        msg_len = struct.unpack_from(self.SIZE_FORMAT, self._data)[0]
        if len(self._data) < self.SIZE_LEN + msg_len:
            return False
        # NOTE(review): everything after the size prefix is taken as the
        # message, including any bytes beyond `msg_len`; the buffer is then
        # cleared, so such extra bytes are not kept for the next message.
        self._message = self._data[self.SIZE_LEN:]
        self._previous = self._data
        self._data = bytes()
        return True

    def add_data(self, data, direction, ts):
        """Feed one TCP payload into the connection.

        Args:
            data: payload bytes (the caller skips empty payloads).
            direction: "to" for data sent to the server; any other value is
                treated as data coming from the server.
            ts: capture timestamp of the packet.
        """
        # A payload equal to the tail of the previous message or of the
        # current buffer is treated as a retransmission and ignored.
        if self._previous[-len(data):] == data \
                or self._data[-len(data):] == data:
            print("Retransmission detected!", file=sys.stderr)
            return

        self._data += data
        self._ts.append(ts)

        if not self._extract_message():
            return

        # The type id is searched for as a little-endian uint64 in the 8-byte
        # words at indices 2..5 of the message.
        for i in range(2, 6):
            if len(self._message) < (i + 1) * 8:
                continue
            message_id = struct.unpack_from("<Q", self._message, i * 8)[0]
            if message_id in MESSAGES:
                break
        else:
            print("Got a message that I can't identify as any known "
                  "RPC request/response!", file=sys.stderr)
            self._last = None
            self._ts = []
            return

        message_type = MESSAGES[message_id]
        message_size = len(self._message)

        # Messages sent to the server are stamped with their last payload,
        # all others with their first payload.
        stamp = self._ts[-1] if direction == "to" else self._ts[0]
        self._requests.append((stamp, direction, message_type, message_size))

        if self._last is None:
            # First half of a pair; wait for the next identified message.
            self._last = (message_type, self._ts[0], message_size)
        else:
            req_type, req_ts, req_size = self._last
            entry = self._stats[(req_type, message_type)]
            entry["duration"].append(self._ts[-1] - req_ts)
            entry["size"].append(req_size + message_size)
            self._last = None

        self._ts = []

    def get_stats(self):
        """Return the ``{(first type, second type): {...}}`` statistics."""
        return self._stats

    def get_requests(self):
        """Return the list of ``(ts, direction, type, size)`` tuples."""
        return self._requests


class Server:
    """Groups the connections of all clients talking to one server."""

    def __init__(self):
        # Client "ip:port" -> Connection.
        self._conns = collections.defaultdict(Connection)

    def add_data(self, addr, data, direction, ts):
        """Forward a payload to the connection of client `addr`."""
        self._conns[addr].add_data(data, direction, ts)

    def print_stats(self, machine_names, title, sort_by):
        """Print a table of per-client, per-RPC statistics.

        Args:
            machine_names: dict mapping a client IP to a display name; an IP
                missing from the dict is displayed as-is.
            title: label placed in the header of the RPC column.
            sort_by: comma separated fields from `SORT_KEYS`, each optionally
                suffixed with "+" (ascending) or "-" (descending). The table
                is stably sorted once per field in the listed order, so the
                last field listed is the most significant one.

        Raises:
            ValueError: if `sort_by` names an unknown field.
        """
        # Merge the statistics of all connections coming from the same IP.
        stats = collections.defaultdict(lambda: collections.defaultdict(
            lambda: {"duration": [], "size": []}))
        for addr, conn in self._conns.items():
            ip = addr.split(":")[0]
            for rpc, connstats in conn.get_stats().items():
                stats[ip][rpc]["duration"] += connstats["duration"]
                stats[ip][rpc]["size"] += connstats["size"]

        table = []
        headers = ["RPC ({})".format(title), "Client", "Count", "Tmin (ms)",
                   "Tavg (ms)", "Tmax (ms)", "Ttot (s)", "Smin (B)",
                   "Savg (B)", "Smax (B)", "Stot (kiB)"]
        for client in sorted(stats):
            client_name = machine_names.get(client, client)
            for rpc, connstats in stats[client].items():
                durs = connstats["duration"]
                sizes = connstats["size"]
                durs_sum = sum(durs)
                sizes_sum = sum(sizes)
                table.append(["{} / {}".format(*rpc), client_name,
                              len(durs), min(durs) * 1000,
                              durs_sum / len(durs) * 1000,
                              max(durs) * 1000, durs_sum, min(sizes),
                              int(sizes_sum / len(sizes)), max(sizes),
                              sizes_sum / 1024])

        for sort_field in sort_by.split(","):
            field_name = sort_field.rstrip("+-")
            if field_name not in SORT_KEYS:
                raise ValueError(
                    "unknown sort field {!r}; available fields: {}".format(
                        field_name, ", ".join(SORT_KEYS)))
            table.sort(key=operator.itemgetter(SORT_KEYS.index(field_name)),
                       reverse=sort_field.endswith("-"))

        print(tabulate.tabulate(table, headers=headers, tablefmt="psql",
                                floatfmt=".3f"))

    def get_requests(self, server_name, machine_names):
        """Return ``(ts, from, to, message, size)`` for every message.

        Args:
            server_name: display name of this server.
            machine_names: dict mapping a client IP to a display name; an IP
                missing from the dict is used as-is.
        """
        ret = []
        for addr, conn in self._conns.items():
            ip = addr.split(":")[0]
            client_name = machine_names.get(ip, ip)
            for ts, direction, message, size in conn.get_requests():
                if direction == "from":
                    name_from, name_to = server_name, client_name
                else:
                    name_from, name_to = client_name, server_name
                ret.append((ts, name_from, name_to, message, size))
        return ret


# process logic

parser = argparse.ArgumentParser(description="Generate RPC statistics from "
                                 "network traffic capture.")
parser.add_argument("--sort-by", default="tavg+,count-,client+",
                    help="comma separated list of fields which should be used "
                    "to sort the data; each field can be suffixed with + or - "
                    "to indicate ascending or descending order; available "
                    "fields: " + ", ".join(SORT_KEYS))
parser.add_argument("--no-aggregate", action="store_true",
                    help="don't aggregate the results, instead display the "
                    "individual RPC calls as they occurred")
parser.add_argument("capfile", help="network traffic capture file")
parser.add_argument("conffile", help="cluster config file")
args = parser.parse_args()

with open(args.conffile) as conffile:
    config = json.load(conffile)

# Machines of type "worker" are numbered (worker1, worker2, ...); every other
# machine is named after its type.
last_worker = 0
machine_names = {}
server_addresses = set()
for machine in config["workload_machines"]:
    name = machine["type"]
    if name == "worker":
        last_worker += 1
        name += str(last_worker)
    machine_names["{address}".format(**machine)] = name
    server_addresses.add("{address}:{port}".format(**machine))

servers = collections.defaultdict(Server)

with open(args.capfile, "rb") as capfile:
    for ts, pkt in dpkt.pcap.Reader(capfile):
        eth = dpkt.ethernet.Ethernet(pkt)
        if eth.type != dpkt.ethernet.ETH_TYPE_IP:
            continue

        ip = eth.data
        if ip.p != dpkt.ip.IP_PROTO_TCP:
            continue

        tcp = ip.data
        src = format_endpoint(ip.src, tcp.sport)
        dst = format_endpoint(ip.dst, tcp.dport)

        # Only traffic with a configured machine endpoint is of interest; when
        # both ends are configured the destination is taken as the server.
        if dst in server_addresses:
            server, client, direction = dst, src, "to"
        elif src in server_addresses:
            server, client, direction = src, dst, "from"
        else:
            continue

        if len(tcp.data) == 0:
            continue

        servers[server].add_data(client, tcp.data, direction, ts)

requests = []
for server in sorted(servers):
    server_name = machine_names[server.split(":")[0]]
    if args.no_aggregate:
        requests.extend(servers[server].get_requests(server_name,
                                                     machine_names))
    else:
        servers[server].print_stats(machine_names=machine_names,
                                    title=server_name,
                                    sort_by=args.sort_by)

if args.no_aggregate:
    requests.sort()
    headers = ["timestamp", "from", "to", "request", "size"]
    print(tabulate.tabulate(requests, headers=headers, tablefmt="psql",
                            floatfmt=".6f"))