#!/usr/bin/env python3
#
# Copyright (C) Catalyst IT Ltd. 2019
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#

import sys
import argparse
import struct
import os
from collections import OrderedDict, Counter
from pprint import pprint

sys.path.insert(0, "bin/python")
import tdb


def unpack_uint(filename, casefold=True):
    """Read a tdb file mapping attribute names to unsigned int counts.

    Every value is unpacked as a single native unsigned int (struct
    format "I") and every key is decoded as UTF-8. When casefold is
    true the decoded key is lower-cased, and the counts of keys that
    become identical are added together.

    Returns a dict of {attribute name: count}.
    """
    totals = {}
    database = tdb.Tdb(filename)
    for raw_key in database:
        (count,) = struct.unpack("I", database[raw_key])
        attr = raw_key.decode('utf-8')
        if casefold:
            attr = attr.lower()
        # Lower-casing can merge several stored keys into one name.
        totals[attr] = totals.get(attr, 0) + count
    return totals


def unpack_ssize_t_pair(filename, casefold):
    """Read a tdb file whose keys are pairs of ssize_t values.

    Each key is unpacked as two native ssize_t values (struct format
    "nn") and each value as a single native unsigned int ("I").

    The casefold argument is not used here; main() passes it to every
    unpacker.

    Returns a list of ((first, second), count) tuples ordered by
    descending count, with ties ordered by descending pair.
    """
    database = tdb.Tdb(filename)
    entries = []
    for raw_key in database:
        pair = struct.unpack("nn", raw_key)
        (count,) = struct.unpack("I", database[raw_key])
        entries.append((pair, count))

    # Largest counts first; the pair itself is the tie-breaker.
    return sorted(entries, key=lambda e: (e[1], e[0]), reverse=True)


# The databases this script knows how to read. Each entry is a tuple of:
#   (label used in the output,
#    path of the tdb file, joined onto the directory given on the
#    command line,
#    function that unpacks the file,
#    human-readable description of what the counts mean)
DATABASES = [
    ('requested', "debug/attr_counts_requested.tdb", unpack_uint,
     "The attribute was specifically requested."),
    ('duplicates', "debug/attr_counts_duplicates.tdb", unpack_uint,
     "Requested more than once in the same request."),
    ('empty request', "debug/attr_counts_empty_req.tdb", unpack_uint,
     "No attributes were requested, but these were returned"),
    ('null request', "debug/attr_counts_null_req.tdb", unpack_uint,
     "The attribute list was NULL and these were returned."),
    ('found', "debug/attr_counts_found.tdb", unpack_uint,
     "The attribute was specifically requested and it was found."),
    ('not found', "debug/attr_counts_not_found.tdb", unpack_uint,
     "The attribute was specifically requested but was not found."),
    ('unwanted', "debug/attr_counts_unwanted.tdb", unpack_uint,
     "The attribute was not requested and it was found."),
    ('star match', "debug/attr_counts_star_match.tdb", unpack_uint,
     'The attribute was not specifically requested but "*" was.'),
    ('req vs found', "debug/attr_counts_req_vs_found.tdb", unpack_ssize_t_pair,
     "How many attributes were requested versus how many were returned."),
]


def plot_pair_data(name, data, doc, lim=90):
    """Show a scatter plot of value pairs, coloured by their count.

    name is used as the plot title. data is a sequence of
    ((x, y), count) items. doc is accepted but not used.

    If lim is truthy, every pair with a coordinate greater than lim is
    reported on stdout and left out of the plot, and the title notes
    how many pairs were excluded.

    If there is nothing left to plot (data was empty, or every pair was
    out of range) a message is printed and no plot is drawn, rather
    than failing while unpacking the empty data.
    """
    if lim:
        kept = []
        for p, c in data:
            if p[0] > lim or p[1] > lim:
                print("not plotting %s: %s" % (p, c))
                continue
            kept.append((p, c))
        skipped = len(data) - len(kept)
        if skipped:
            name += " (excluding %d out of range values)" % skipped
        data = kept

    if not data:
        # zip(*data) below cannot be unpacked when there are no points.
        print("nothing to plot for %s" % name)
        return

    # Note we keep the matplotlib import internal to this function for
    # two reasons:
    # 1. Some people won't have matplotlib, but might want to run the
    #    script.
    # 2. The import takes hundreds of milliseconds, which is a
    #    nuisance if you don't want graphs.
    #
    # This plot could be improved!
    import matplotlib.pylab as plt
    _fig, ax = plt.subplots()
    xy, counts = zip(*data)
    x, y = zip(*xy)
    ax.set_title(name)
    ax.scatter(x, y, c=counts)
    plt.show()


def print_pair_data(name, data, doc):
    """Print a three-column table of requested/returned pairs and counts.

    name and doc are printed as two heading lines. data is a sequence
    of ((requested, returned), count) items. A requested value of -2 is
    shown as 'NULL' and -4 as '*'; anything else is printed as is.
    """
    t = "%14s | %14s | %14s"
    rule = '-' * 14

    print(name)
    print(doc)
    print(t % ("requested", "returned", "count"))
    print(t % (rule, rule, rule))

    for (requested, returned), count in data:
        # Negative sentinels stand for special attribute lists.
        shown = 'NULL' if requested == -2 else '*' if requested == -4 else requested
        print(t % (shown, returned, count))


def print_counts(count_data):
    """Print a table of per-attribute counts, one column per database.

    count_data is a sequence of (column name, {attribute: count}, doc)
    tuples. The output is, in order: the number of distinct attributes,
    one "name: doc" legend line per column, a blank line, a header row,
    a dashed rule, and one row per attribute. Rows are ordered by the
    attribute's total across all columns, largest first; a column with
    no entry for an attribute is left blank.

    An empty count_data, or one whose dicts are all empty, prints just
    the header and rule instead of raising ValueError.
    """
    all_attrs = Counter()
    for c in count_data:
        all_attrs.update(c[1])

    print("found %d attrs" % len(all_attrs))

    # One row per attribute, most frequent overall first. Each row starts
    # with the attribute name and gains one cell per column below.
    rows = OrderedDict((a, [a]) for a, _ in all_attrs.most_common())

    for col_name, counts, doc in count_data:
        for attr, row in rows.items():
            row.append(counts.get(attr, ''))

        print("%15s: %s" % (col_name, doc))
    print()

    heading = "attribute"
    # default=0 copes with there being no attributes at all; the column
    # is never narrower than its heading so the header stays aligned.
    longest = max((len(a) for a in all_attrs), default=0)
    widths = [max(longest, len(heading))]
    widths.extend(max(len(c[0]), 7) for c in count_data)
    t = " | ".join("%{}s".format(w) for w in widths)

    h = t % ((heading,) + tuple(c[0] for c in count_data))
    print(h)
    print("-" * len(h))

    for row in rows.values():
        print(t % tuple(row))


def main():
    """Parse the command line, read every known database and report it.

    The positional argument is a directory; if it is not a directory
    the usage line is printed and the script exits with status 1. Each
    entry in DATABASES is unpacked from that directory, and files that
    raise RuntimeError or IOError are reported and skipped. Pair
    databases are printed (and plotted when --plot is given) first,
    followed by the combined table of attribute counts.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('LDB_PRIVATE_DIR',
                        help="read attr counts in this directory")
    parser.add_argument('--plot', action="store_true",
                        help='attempt to draw graphs')
    parser.add_argument('--no-casefold', action="store_false",
                        default=True, dest="casefold",
                        help='See all the encountered case variants')
    args = parser.parse_args()

    private_dir = args.LDB_PRIVATE_DIR
    if not os.path.isdir(private_dir):
        parser.print_usage()
        sys.exit(1)

    count_data = []
    pair_data = []
    for label, relative_path, unpacker, doc in DATABASES:
        path = os.path.join(private_dir, relative_path)
        try:
            unpacked = unpacker(path, casefold=args.casefold)
        except (RuntimeError, IOError) as e:
            print("could not parse %s: %s" % (path, e))
            continue
        # Pair databases get their own report; the rest share one table.
        bucket = pair_data if unpacker is unpack_ssize_t_pair else count_data
        bucket.append((label, unpacked, doc))

    for label, pairs, doc in pair_data:
        if args.plot:
            plot_pair_data(label, pairs, doc)
        print_pair_data(label, pairs, doc)

    print()
    print_counts(count_data)

# Script entry point. There is no __name__ guard here, so main() runs
# whenever this file is executed or imported.
main()
