#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
RPC throughput barcharts generator (based on RPC benchmark JSON output).
To obtain data used for this script use:
`./tests/benchmark/rpc --benchmark_out=test.json --benchmark_out_format=json`
"""

import json
import collections
import os
import argparse
import numpy
import shutil

import matplotlib
# Must set "Agg" backend before importing pyplot
# This is so the script works on headless machines (without X11)
matplotlib.use("Agg")
import matplotlib.pyplot as plt


def parse_args():
    """Parse the command line and return the resulting namespace.

    Two positional arguments are required (``input`` and ``output``).
    Three optional integer flags control the geometry of the generated
    images: ``--plot-width``, ``--plot-height`` and ``--plot-dpi``.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    parser = argparse.ArgumentParser(description=__doc__)

    # Required positional paths.
    parser.add_argument("input", help="Load data from file.")
    parser.add_argument("output", help="Save plots to this directory.")

    # Optional integer flags: (flag, default value, help text).
    geometry_options = (
        ("--plot-width", 1920, "Pixel width of generated plots."),
        ("--plot-height", 1080, "Pixel height of generated plots."),
        ("--plot-dpi", 96, "DPI of generated plots."),
    )
    for flag, default_value, help_text in geometry_options:
        parser.add_argument(flag, type=int, default=default_value,
                            help=help_text)

    return parser.parse_args()


def humanize(num):
    """Format a number as a rounded integer with an SI-style suffix.

    The value is divided by 1000 until its rounded form drops below 1000
    or the largest suffix ("Y") is reached, e.g. ``1234567`` -> ``"1M"``.

    Args:
        num: The number to format (int or float).

    Returns:
        The rounded integer as a string, followed by the matching suffix
        ("" for values that round to less than 1000).
    """
    suffix = ["", "k", "M", "G", "T", "P", "E", "Z", "Y"]
    last = len(suffix) - 1
    pos = 0
    # Decide on the *rounded* value: the output is rounded, so a value such
    # as 999.6 must advance to the next suffix ("1k") instead of being
    # printed as "1000".
    while pos < last and round(num) >= 1000:
        num /= 1000
        pos += 1
    return str(int(round(num))) + suffix[pos]


def dehumanize(num):
    """Convert a string such as ``"64k"`` back to its integer value.

    A single trailing suffix out of k, M, G, T, P, E, Z, Y multiplies the
    leading integer by the matching power of 1000. A string without one of
    these suffixes is converted with ``int`` as is.

    Args:
        num: String holding an integer, optionally followed by one suffix.

    Returns:
        The integer value with the suffix expanded.
    """
    multiplier = 1
    for letter in "kMGTPEZY":
        # Each successive suffix is worth another factor of 1000.
        multiplier *= 1000
        if num.endswith(letter):
            return int(num[:-1]) * multiplier
    return int(num)


def autolabel(ax, rects):
    """Write the humanized height of every bar just above that bar.

    Args:
        ax: Axes object the text labels are added to.
        rects: Iterable of bar objects exposing ``get_x``, ``get_width``
            and ``get_height``.
    """
    for bar in rects:
        bar_height = bar.get_height()
        # Horizontal midpoint of the bar; the label is centered on it.
        center_x = bar.get_x() + bar.get_width() / 2.
        ax.text(center_x, 1.00 * bar_height, humanize(bar_height),
                ha="center", va="bottom")


def generate_plot(size, results, plot_width, plot_height, plot_dpi,
                  output_file):
    """Render one grouped bar chart of RPC throughput and save it to a file.

    Each key of ``results`` becomes one group on the x axis, and each
    ``(name, throughput)`` entry becomes one bar within that group.

    Args:
        size: Size label; only used in the plot title.
        results: Mapping of group key (shown on the "Concurrent threads"
            axis) to a list of ``(name, throughput)`` tuples. Bars are
            matched by list position, so every list must hold the same
            names in the same order as the first (sorted) group.
        plot_width: Width of the saved image in pixels.
        plot_height: Height of the saved image in pixels.
        plot_dpi: DPI used to convert pixels to inches and to save.
        output_file: Destination path handed to ``Figure.savefig``.
    """
    # Font size.
    plt.rc("font", size=10)
    plt.rc("axes", titlesize=24)
    plt.rc("axes", labelsize=16)
    plt.rc("xtick", labelsize=12)
    plt.rc("ytick", labelsize=12)
    plt.rc("legend", fontsize=16)
    plt.rc("figure", titlesize=24)

    groups = sorted(results.keys())
    categories = [entry[0] for entry in results[groups[0]]]

    # Plot.
    ind = numpy.arange(len(groups))
    # Keep the usual 0.10 bar width, but shrink it when there are so many
    # categories that a cluster would spill into the neighbouring group.
    width = min(0.10, 0.8 / max(len(categories), 1))
    fig, ax = plt.subplots()
    try:
        fig.set_size_inches(plot_width / plot_dpi,
                            plot_height / plot_dpi)
        ax.set_xlabel("Concurrent threads")
        ax.set_ylabel("Throughput (call/s)")
        ax.set_facecolor("#dcdcdc")
        # Bars are centered at ind + index * width, so the middle of a
        # cluster of n bars lies at ind + (n - 1) * width / 2.
        ax.set_xticks(ind + width * (len(categories) - 1) / 2)
        ax.set_xticklabels(groups)
        for line in ax.get_xgridlines():
            line.set_linestyle(" ")
        for line in ax.get_ygridlines():
            line.set_linestyle("--")
        ax.set_axisbelow(True)
        ax.grid(True)
        ax.set_title("RPC throughput (size: {})".format(size))

        # Draw bars.
        rects = []
        for index, category in enumerate(categories):
            category_results = [results[group][index][1]
                                for group in groups]
            rect = ax.bar(ind + index * width, category_results, width)
            rects.append(rect)
            autolabel(ax, rect)
        ax.legend(rects, categories)

        # Plot output.
        fig.savefig(output_file, dpi=plot_dpi)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it so generating many plots does not accumulate memory.
        plt.close(fig)


def main():
    """Load benchmark JSON, group the results and write one plot per size.

    Every entry of ``data["benchmarks"]`` must have a ``name`` whose first
    two "/"-separated fields are the benchmark name and the size, and whose
    last field has the form ``<label>:<threads>``. Throughput is read from
    ``items_per_second``.

    The output path is removed if it already exists (directory or file)
    and recreated before the plots are written.
    """
    # Read the arguments.
    args = parse_args()

    # Load data. JSON is read as UTF-8 rather than with the platform's
    # locale-dependent default encoding.
    with open(args.input, encoding="utf-8") as f:
        data = json.load(f)

    # Process data: results[size][threads] -> [(name, throughput), ...].
    results = collections.defaultdict(lambda: collections.defaultdict(list))
    for benchmark in data["benchmarks"]:
        info = benchmark["name"].split("/")
        name, size = info[0:2]
        threads = int(info[-1].split(":")[1])
        throughput = benchmark["items_per_second"]
        results[size][threads].append((name, throughput))

    # Cleanup output directory.
    if os.path.isdir(args.output):
        shutil.rmtree(args.output)
    if os.path.exists(args.output):
        os.remove(args.output)
    # makedirs also creates missing parent directories, so a nested
    # output path works.
    os.makedirs(args.output)

    # Generate plots, ordered by numeric size; the index prefix keeps the
    # files in that order when sorted by name.
    sizes = sorted(results.keys(), key=dehumanize)
    for prefix, size in enumerate(sizes):
        result = results[size]
        output_file = os.path.join(args.output,
                                   "{}_{}.png".format(prefix, size))
        generate_plot(size, result, args.plot_width, args.plot_height,
                      args.plot_dpi, output_file)


# Run the generator only when this file is executed as a script, so that
# importing the module has no side effects beyond its definitions.
if __name__ == "__main__":
    main()