memgraph/tools/analyze_rpc_calls

269 lines
9.4 KiB
Plaintext
Raw Normal View History

#!/usr/bin/env python3
import argparse
import collections
import dpkt
import json
import operator
import os
import socket
import subprocess
import struct
import sys
import tabulate
# Absolute, symlink-resolved directory containing this script, and the
# directory one level above it (used as the project root).
_this_script = os.path.realpath(__file__)
SCRIPT_DIR = os.path.dirname(_this_script)
PROJECT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, os.pardir))
# helpers
def format_endpoint(addr, port):
    """Return "a.b.c.d:port" for a packed 4-byte IPv4 address and a port."""
    host = socket.inet_ntoa(addr)
    return ":".join((host, str(port)))
def parse_capnp_header(fname):
    """Map capnp struct IDs to struct names for one generated header.

    Assumes the capnp C++ generator layout: each struct is opened with a
    ``struct Name {`` line (``struct Outer::Inner {`` for nested ones) and
    contains a ``_capnpPrivate`` member holding a
    ``CAPNP_DECLARE_STRUCT_HEADER(<id-hex>, ...)`` line. Each ID is
    attributed to the most recently opened struct definition.

    :param fname: path of the ``.capnp.h`` file to scan.
    :return: dict mapping the integer struct ID to the struct name.
    """
    ret = {}
    last_struct = ""
    with open(fname, encoding="utf-8") as f:
        for row in f:
            row = row.strip()
            # Only definitions count: forward declarations such as
            # "struct Inner;" appear inside the outer struct *before* its
            # _capnpPrivate, and would otherwise steal the outer struct's ID.
            if (row.startswith("struct ") and not row.endswith(";")
                    and not row.startswith("struct _capnpPrivate")):
                last_struct = row.split()[1]
            elif row.startswith("CAPNP_DECLARE_STRUCT_HEADER"):
                id_hex = row.split("(", 1)[1].split(",", 1)[0]
                ret[int(id_hex, 16)] = last_struct
    return ret
# TODO(mferencevic): Update this to parse .cpp files (99% are .lcp.cpp),
# containing the following line.
#
# const utils::TypeInfo <class-name>::kType{<id-hex>, "<class-name>"};
#
# Note that clang-format may break the line at any of the spaces or after '{'.
def parse_all_capnp_headers(dirname):
    """Collect capnp struct IDs from every ``*.capnp.h`` file under a tree.

    Walks ``dirname`` recursively and merges the results of
    ``parse_capnp_header`` for each matching regular file. Files are visited
    in sorted path order so the result is deterministic if two headers
    declare the same ID (the later path wins).

    :param dirname: root directory to search.
    :return: dict mapping integer struct ID to struct name.
    :raises FileNotFoundError: if ``dirname`` is not an existing directory.
    """
    if not os.path.isdir(dirname):
        raise FileNotFoundError(
            "capnp header directory not found: {}".format(dirname))
    headers = []
    for root, _subdirs, files in os.walk(dirname):
        headers.extend(os.path.join(root, name) for name in files
                       if name.endswith(".capnp.h"))
    ids = {}
    for header in sorted(headers):
        ids.update(parse_capnp_header(header))
    return ids
# capnp struct ID -> struct name for every generated header under src/,
# built once at import time; Connection looks message IDs up here.
_CAPNP_SRC_DIR = os.path.join(PROJECT_DIR, "src")
MESSAGES = parse_all_capnp_headers(_CAPNP_SRC_DIR)
class Connection:
    """Reassemble one client connection's RPC frames and pair them up.

    Payloads from both directions are appended to a single buffer, so the
    code relies on the client and server strictly alternating: the first
    recognised message is treated as a request and the next one as its
    response.
    """

    # uint32_t message_size prefix; little-endian, like the message IDs below
    SIZE_FORMAT = "<I"
    SIZE_LEN = struct.calcsize(SIZE_FORMAT)

    def __init__(self):
        self._previous = bytes()  # last extracted frame (retransmission check)
        self._data = bytes()  # stream bytes not yet consumed as a frame
        self._message = bytes()  # payload of the last extracted frame
        self._ts = []  # capture timestamps of packets in the current frame
        self._last = None  # (type, ts, size) of a request awaiting a response
        self._stats = collections.defaultdict(lambda: {"duration": [],
                                                       "size": []})
        self._requests = []

    def _extract_message(self):
        """Pop one length-prefixed frame off the buffer.

        On success the payload is stored in ``self._message`` and any bytes
        after the frame stay buffered for the next call.

        :return: True if a complete frame was extracted, False if more data
            is needed.
        """
        if len(self._data) < self.SIZE_LEN:
            return False
        msg_len = struct.unpack_from(self.SIZE_FORMAT, self._data)[0]
        end = self.SIZE_LEN + msg_len
        if len(self._data) < end:
            return False
        # Slice exactly one frame; the remainder (possibly the start of the
        # next message sharing this TCP payload) must not be discarded.
        self._message = self._data[self.SIZE_LEN:end]
        self._previous = self._data[:end]
        self._data = self._data[end:]
        return True

    @staticmethod
    def _find_message_id(message):
        """Return the first known struct ID among 64-bit words 2..5, or None."""
        for i in range(2, 6):
            if len(message) < (i + 1) * 8:
                continue
            message_id = struct.unpack("<Q", message[i * 8:(i + 1) * 8])[0]
            if message_id in MESSAGES:
                return message_id
        return None

    def _handle_message(self, direction):
        """Classify ``self._message`` and record it in requests/stats."""
        message_id = self._find_message_id(self._message)
        if message_id is None:
            print("Got a message that I can't identify as any known "
                  "RPC request/response!", file=sys.stderr)
            self._last = None
            return
        message_type = MESSAGES[message_id]
        size = len(self._message)
        # Messages sent to the server are stamped with the capture time of
        # their last packet, messages from the server with their first one.
        ts = self._ts[-1] if direction == "to" else self._ts[0]
        self._requests.append((ts, direction, message_type, size))
        if self._last is None:
            self._last = (message_type, self._ts[0], size)
            return
        req_type, req_ts, req_size = self._last
        key = (req_type, message_type)
        self._stats[key]["duration"].append(self._ts[-1] - req_ts)
        self._stats[key]["size"].append(req_size + size)
        self._last = None

    def add_data(self, data, direction, ts):
        """Feed one TCP payload captured at time ``ts``.

        :param data: TCP payload bytes.
        :param direction: "to" for client->server, "from" for server->client.
        :param ts: capture timestamp of the packet.
        """
        if not data:
            return
        # Heuristic: a payload equal to the tail of what was already seen is
        # treated as a TCP retransmission and dropped.
        if self._previous[-len(data):] == data \
                or self._data[-len(data):] == data:
            print("Retransmission detected!", file=sys.stderr)
            return
        self._data += data
        self._ts.append(ts)
        while self._extract_message():
            self._handle_message(direction)
            # Leftover bytes are the start of a frame that began in this
            # same packet, so it inherits this packet's timestamp.
            self._ts = [ts] if self._data else []

    def get_stats(self):
        """Return {(request_type, response_type): {"duration", "size"}}."""
        return self._stats

    def get_requests(self):
        """Return a list of (ts, direction, message_type, size) tuples."""
        return self._requests
class Server:
    """Track every client connection to one RPC server endpoint.

    Connections are keyed by the client's "ip:port" string; each one is
    handled by its own Connection.
    """

    def __init__(self):
        self._conns = collections.defaultdict(Connection)

    def add_data(self, addr, data, direction, ts):
        """Route a TCP payload to the Connection of client ``addr``."""
        self._conns[addr].add_data(data, direction, ts)

    def _aggregate_stats(self):
        """Merge the stats of all connections by client IP (port dropped).

        :return: {ip: {(request, response): {"duration": [...],
            "size": [...]}}}
        """
        stats = collections.defaultdict(lambda: collections.defaultdict(
            lambda: {"duration": [], "size": []}))
        for addr, conn in self._conns.items():
            ip = addr.split(":")[0]
            for rpc, connstats in conn.get_stats().items():
                stats[ip][rpc]["duration"] += connstats["duration"]
                stats[ip][rpc]["size"] += connstats["size"]
        return stats

    def print_stats(self, machine_names, title, sort_by):
        """Print a table of per-(RPC, client) timing and size statistics.

        :param machine_names: dict mapping client IP to a display name;
            clients missing from it are shown by IP.
        :param title: server name shown in the first column header.
        :param sort_by: comma separated keys (rpc, client, count, tmin, tavg,
            tmax, ttot, smin, savg, smax, stot), each optionally suffixed
            with "+" (ascending) or "-" (descending). Sorts are applied in
            order with a stable sort, so the LAST key listed is primary.
        :raises ValueError: if ``sort_by`` names an unknown field.
        """
        headers = ["RPC ({})".format(title), "Client", "Count", "Tmin (ms)",
                   "Tavg (ms)", "Tmax (ms)", "Ttot (s)", "Smin (B)",
                   "Savg (B)", "Smax (B)", "Stot (kiB)"]
        sort_keys = ["rpc", "client", "count", "tmin", "tavg", "tmax", "ttot",
                     "smin", "savg", "smax", "stot"]

        # Validate the sort specification before doing any work.
        sort_specs = []
        for sort_field in sort_by.split(","):
            key_name = sort_field.rstrip("+-")
            if key_name not in sort_keys:
                raise ValueError(
                    "unknown sort field {!r}; available fields: {}".format(
                        key_name, ", ".join(sort_keys)))
            sort_specs.append((sort_keys.index(key_name),
                               sort_field.endswith("-")))

        stats = self._aggregate_stats()
        table = []
        for client in sorted(stats):
            client_name = machine_names.get(client, client)
            for rpc, connstats in stats[client].items():
                durs = connstats["duration"]
                sizes = connstats["size"]
                durs_sum = sum(durs)
                sizes_sum = sum(sizes)
                table.append(["{} / {}".format(*rpc), client_name,
                              len(durs), min(durs) * 1000,
                              durs_sum / len(durs) * 1000,
                              max(durs) * 1000, durs_sum, min(sizes),
                              int(sizes_sum / len(sizes)), max(sizes),
                              sizes_sum / 1024])

        for column, reverse in sort_specs:
            table.sort(key=operator.itemgetter(column), reverse=reverse)
        print(tabulate.tabulate(table, headers=headers, tablefmt="psql",
                                floatfmt=".3f"))

    def get_requests(self, server_name, machine_names):
        """Return (ts, from_name, to_name, message_type, size) tuples.

        :param server_name: display name of this server.
        :param machine_names: dict mapping client IP to a display name;
            clients missing from it are reported by IP.
        """
        ret = []
        for addr, conn in self._conns.items():
            ip = addr.split(":")[0]
            client_name = machine_names.get(ip, ip)
            for ts, direction, message, size in conn.get_requests():
                if direction == "from":
                    name_from, name_to = server_name, client_name
                else:
                    name_from, name_to = client_name, server_name
                ret.append((ts, name_from, name_to, message, size))
        return ret
# process logic
parser = argparse.ArgumentParser(description="Generate RPC statistics from "
                                 "network traffic capture.")
parser.add_argument("--sort-by", default="tavg+,count-,client+",
                    help="comma separated list of fields which should be used "
                    "to sort the data; each field can be suffixed with + or - "
                    "to indicate ascending or descending order and the last "
                    "field listed is the primary sort key; available fields: "
                    "rpc, client, count, tmin, tavg, tmax, ttot, smin, savg, "
                    "smax, stot")
parser.add_argument("--no-aggregate", action="store_true",
                    help="don't aggregate the results, instead display the "
                    "individual RPC calls as they occurred")
parser.add_argument("capfile", help="network traffic capture file")
parser.add_argument("conffile", help="cluster config file")
args = parser.parse_args()

with open(args.conffile) as conffile:
    config = json.load(conffile)

# Give every machine a display name (workers are numbered in config order)
# and collect the "address:port" endpoints whose traffic is analyzed.
last_worker = 0
machine_names = {}
server_addresses = set()
for machine in config["workload_machines"]:
    name = machine["type"]
    if name == "worker":
        last_worker += 1
        name += str(last_worker)
    machine_names["{address}".format(**machine)] = name
    server_addresses.add("{address}:{port}".format(**machine))

servers = collections.defaultdict(Server)
with open(args.capfile, "rb") as capfile:
    for ts, pkt in dpkt.pcap.Reader(capfile):
        eth = dpkt.ethernet.Ethernet(pkt)
        if eth.type != dpkt.ethernet.ETH_TYPE_IP:
            continue
        ip = eth.data
        # dpkt can leave a layer's payload as undecoded bytes (e.g. for
        # truncated packets); skip those instead of crashing on attributes.
        if not isinstance(ip, dpkt.ip.IP) or ip.p != dpkt.ip.IP_PROTO_TCP:
            continue
        tcp = ip.data
        if not isinstance(tcp, dpkt.tcp.TCP) or len(tcp.data) == 0:
            continue
        src = format_endpoint(ip.src, tcp.sport)
        dst = format_endpoint(ip.dst, tcp.dport)
        if dst in server_addresses:
            server, client, direction = dst, src, "to"
        elif src in server_addresses:
            server, client, direction = src, dst, "from"
        else:
            continue
        servers[server].add_data(client, tcp.data, direction, ts)

requests = []
for server in sorted(servers):
    server_name = machine_names[server.split(":")[0]]
    if args.no_aggregate:
        requests.extend(servers[server].get_requests(server_name,
                                                     machine_names))
    else:
        servers[server].print_stats(machine_names=machine_names,
                                    title=server_name,
                                    sort_by=args.sort_by)
if args.no_aggregate:
    requests.sort()
    headers = ["timestamp", "from", "to", "request", "size"]
    print(tabulate.tabulate(requests, headers=headers, tablefmt="psql",
                            floatfmt=".6f"))