#!/usr/bin/env python3
# -*- coding: utf-8 -*-

'''
Latency Barchart (Based on LDBC JSON output).
'''

import json
import os
import numpy as np
from argparse import ArgumentParser
import string

import matplotlib
# Must set 'Agg' backend before importing pyplot
# This is so the script works on headless machines (without X11)
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.cbook import get_sample_data


# Directory that contains this script (realpath resolves symlinks first).
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
# Bar colors; main() picks COLORS[i] for the i-th vendor, so there is exactly
# one color per supported vendor.
COLORS = ['#ff7300', '#008cc2']
# Number of units per second, keyed by the "unit" string read from the LDBC
# results JSON. A measured value is divided by this factor to get seconds.
LDBC_TIME_FACTORS = {
    "SECONDS": 1.0,
    "MILLISECONDS": 1000.0,
    "MICROSECONDS": 1000000.0,
    "NANOSECONDS": 1000000000.0
}
# Number of units per second, keyed by the values accepted by --time-unit.
# A value in seconds is multiplied by this factor to get the output unit.
TIME_FACTORS = {
    "s": 1.0,
    "ms": 1000,
    "us": 1000000,
    "ns": 1000000000,
}


def parse_args():
    """
    Parse the command-line arguments of the script.

    Returns:
        The argparse namespace with the attributes `vendor_titles`,
        `plot_title`, `logo_path`, `results`, `time_unit`, `output`,
        `plot_width`, `plot_height` and `dpi`.
    """
    argp = ArgumentParser(description=__doc__)
    argp.add_argument("--vendor-titles", nargs="+",
                      default=["Memgraph", "Market leader"],
                      help="Vendor titles that are going to appear "
                           "on the plot, e.g. legend titles.")
    argp.add_argument("--plot-title", default="",
                      help="Plot title.")
    argp.add_argument("--logo-path", default=None,
                      help="Path to the logo that is going to be presented"
                           " instead of title.")
    # Each value is opened as a JSON file and paired, in order, with the
    # corresponding --vendor-titles entry. argparse formats help strings
    # with '%', not str.format, so braces must not be doubled here.
    argp.add_argument("--results", nargs="+", required=True,
                      help="Paths to the LDBC result JSON files, one per "
                           "vendor and in the same order as "
                           "--vendor-titles, e.g. "
                           "{vendor-reference}-LDBC-results.json")
    argp.add_argument("--time-unit", choices=("s", "ms", "us", "ns"),
                      default="ms", help="The time unit that should be used.")
    argp.add_argument("--output", default="",
                      help="Save plot to file (instead of displaying it).")
    argp.add_argument("--plot-width", type=int, default=1920,
                      help="Pixel width of generated plots.")
    argp.add_argument("--plot-height", type=int, default=1080,
                      help="Pixel height of generated plots.")
    argp.add_argument("--dpi", type=int, default=96,
                      help="DPI of generated plots.")
    return argp.parse_args()


def autolabel(ax, rects):
    """
    Write each bar's height, formatted with one decimal place, as a text
    label horizontally centered on the bar and anchored at its top edge.
    """
    label_style = {'ha': 'center', 'va': 'bottom'}
    for bar in rects:
        bar_height = bar.get_height()
        # TODO: adjust more vendors
        center_x = bar.get_x() + bar.get_width() / 2.
        label_y = 1.00 * bar_height
        ax.text(center_x, label_y, format(bar_height, '.1f'), **label_style)


def _split_query_name(name):
    """
    Split a query name at its first run of ASCII digits.

    Returns a `(prefix, number)` tuple of strings. `number` is the empty
    string (and `prefix` the whole name) when the name has no digits.
    """
    for start, char in enumerate(name):
        if char in string.digits:
            break
    else:
        return name, ""
    end = start
    while end < len(name) and name[end] in string.digits:
        end += 1
    return name[:start], name[start:end]


def _shorten_query_name(name):
    """
    Drop a leading "ldbc" (any case) and everything after the query number.

    Long query names on the x-axis don't look compelling. Names without any
    digit are returned unchanged apart from the stripped "ldbc" prefix.
    """
    if name.lower().startswith("ldbc"):
        name = name[4:]
    prefix, number = _split_query_name(name)
    return prefix + number


def _query_sort_key(result):
    """
    Sort key for `(query_name, latency)` tuples: name prefix first, then the
    query number compared numerically (so "Query2" sorts before "Query10").
    """
    prefix, number = _split_query_name(result[0])
    # Names without a number sort before numbered ones with the same prefix.
    return (prefix, int(number) if number else -1)


def _load_vendor_results(results_path, time_unit):
    """
    Read one LDBC results JSON file.

    Returns a list of `(short_query_name, mean_latency)` tuples, with the
    mean latency converted from the unit declared in the file to
    `time_unit` (a key of TIME_FACTORS).
    """
    with open(results_path, encoding="utf-8") as results_file:
        results_data = json.load(results_file)
    results = []
    for query_data in results_data["all_metrics"]:
        # Convert to seconds first, then to the requested output unit.
        mean_runtime = (query_data["run_time"]["mean"] /
                        LDBC_TIME_FACTORS[results_data["unit"]] *
                        TIME_FACTORS[time_unit])
        results.append((_shorten_query_name(query_data['name']),
                        mean_runtime))
    return results


def _plot_latencies(args, vendors, query_names):
    """
    Draw the grouped bar chart (one bar per vendor and query), then either
    save it to `args.output` or call `plt.show()` when no output is given.
    """
    # Font size.
    plt.rc('font', size=12)          # controls default text sizes
    plt.rc('axes', titlesize=24)     # fontsize of the axes title
    plt.rc('axes', labelsize=16)     # fontsize of the x and y labels
    plt.rc('xtick', labelsize=12)    # fontsize of the tick labels
    plt.rc('ytick', labelsize=12)    # fontsize of the tick labels
    plt.rc('legend', fontsize=16)    # legend fontsize
    plt.rc('figure', titlesize=24)   # fontsize of the figure title

    ind = np.arange(len(query_names))   # the x locations for the groups
    width = 0.40                        # the width of the bars
    fig, ax = plt.subplots()            # figure setup
    fig.set_size_inches(args.plot_width / args.dpi,
                        args.plot_height / args.dpi)  # set figure size
    ax.set_ylabel('Mean Latency (%s)' % (args.time_unit))  # YAxis title
    ax.set_facecolor('#dcdcdc')         # plot bg color (light gray)
    ax.set_xticks(ind + width / len(vendors))  # TODO: adjust (more vendors)
    ax.set_xticklabels(query_names, rotation=30)
    # set only horizontal grid lines
    for line in ax.get_xgridlines():
        line.set_linestyle(' ')
    for line in ax.get_ygridlines():
        line.set_linestyle('--')
    ax.set_axisbelow(True)              # put the grid below all other elements
    plt.grid(True)                      # show grid
    # The title is always set; the logo (if any) is drawn above the plot.
    ax.set_title(args.plot_title)
    if args.logo_path is not None:
        # Read the image straight from its path so that imread opens and
        # closes the file itself (relative paths resolve against the cwd).
        im = plt.imread(os.path.abspath(args.logo_path))
        plt.gcf().subplots_adjust(top=0.85)
        # magic numbers for logo size - DO NOT TOUCH!
        newax = fig.add_axes([0.46, 0.85, 0.126, 0.15], anchor='N')
        newax.imshow(im)
        newax.axis('off')
    # Draw bars; each vendor's bars are shifted right by one bar width.
    for index, vendor in enumerate(vendors):
        latencies = [res[1] for res in vendor['results']]
        rects = ax.bar(ind + index * width, latencies, width,
                       color=vendor['color'])
        vendor['rects'] = rects
        autolabel(ax, rects)
    # One legend entry per vendor, represented by that vendor's first bar.
    handles = [vd['rects'][0] for vd in vendors]
    titles = [vd['title'] for vd in vendors]
    ax.legend(handles, titles)

    if args.output == "":
        # NOTE(review): this module selects the 'Agg' backend at import
        # time, so show() presumably cannot open a window - confirm.
        plt.show()
    else:
        plt.savefig(args.output, dpi=args.dpi)


def main():
    """
    Script entry point: load the per-vendor LDBC results, print the mean
    latencies and plot them as a grouped bar chart.

    Raises:
        ValueError: if the number of vendors is not 2, if the vendors did
            not run the same set of queries, or if there are no queries.
    """
    # Read the arguments.
    args = parse_args()

    # Prepare the datastructure. Results and titles are paired in order;
    # surplus entries on either side are ignored.
    vendor_inputs = list(zip(args.results, args.vendor_titles))
    # Validate before indexing COLORS, which has one entry per vendor.
    if len(vendor_inputs) != 2:
        raise ValueError("The graph is tailored for only 2 vendors.")
    vendors = [
        {
            'title': vendor_title,
            'results_file': results_file,
            'color': COLORS[i],
            'results': [],
        }
        for i, (results_file, vendor_title) in enumerate(vendor_inputs)
    ]

    # Collect the benchmark data.
    print("LDBC Latency Data")
    for vendor in vendors:
        vendor['results'] = _load_vendor_results(vendor['results_file'],
                                                 args.time_unit)

    # Sort results.
    for vendor in vendors:
        vendor['results'].sort(key=_query_sort_key)

    # Print results.
    for vendor in vendors:
        print("Vendor:", vendor['title'])
        for query_name, latency in vendor['results']:
            print("{} -> {:.3f}{}".format(query_name, latency, args.time_unit))

    # Consistency check: every vendor must have the same queries in the
    # same (sorted) order, because bars are grouped by position.
    all_query_names = [tuple(res[0] for res in vd['results'])
                       for vd in vendors]
    if len(set(all_query_names)) != 1:
        raise ValueError("Queries between different vendors are different!")
    query_names = all_query_names[0]
    if not query_names:
        raise ValueError("The result files contain no queries to plot.")

    _plot_latencies(args, vendors, query_names)


if __name__ == '__main__':
    main()