#!/usr/bin/python3
#
# Examine build performance test results
#
# Copyright (c) 2017, Intel Corporation.
#
# This program is free software; you can redistribute it and/or modify it
# under the terms and conditions of the GNU General Public License,
# version 2, as published by the Free Software Foundation.
#
# This program is distributed in the hope it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
# more details.
#
import argparse
import json
import logging
import os
import re
import sys
from collections import namedtuple, OrderedDict
from operator import attrgetter
from xml.etree import ElementTree as ET

# Import oe libs
scripts_path = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(scripts_path, 'lib'))
import scriptpath
from build_perf import print_table
from build_perf.report import (metadata_xml_to_json, results_xml_to_json,
                               aggregate_data, aggregate_metadata, measurement_stats)
from build_perf import html

scriptpath.add_oe_lib_path()

from oeqa.utils.git import GitRepo, GitError


# Setup logging: messages of level INFO and above are shown by default, each
# prefixed with its level name. main() switches this logger to DEBUG level
# when --debug is given.
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
log = logging.getLogger('oe-build-perf-report')


# Container class for tested revisions: the commit, its commit number (kept as
# the string parsed from the tag name) and the list of result tags found for
# that commit
TestedRev = namedtuple('TestedRev', 'commit commit_number tags')


def get_test_runs(repo, tag_name, **kwargs):
    """Get a sorted list of test runs, matching given pattern

    Arguments:
        repo -- repository object; only its run_cmd() method is used here, to
                list tags with 'tag -l'
        tag_name -- tag name pattern with '{field}' style placeholders
        kwargs -- fixed values for some of the fields of the pattern

    Returns a tuple (undef_fields, runs). 'undef_fields' contains the names
    of the pattern fields that were not given in kwargs, in the order they
    appear in the pattern. 'runs' is a sorted list having one item per parsed
    tag: a list of the (string) values of 'undef_fields' followed by the tag
    name itself. Tags that cannot be parsed are skipped.
    """
    # First, get field names from the tag name pattern
    field_names = [m.group(1) for m in re.finditer(r'{(\w+)}', tag_name)]
    undef_fields = [f for f in field_names if f not in kwargs]

    # Fields for formatting tag name pattern: wildcard for undefined fields
    glob_fields = {f: '*' for f in field_names}
    glob_fields.update(kwargs)

    # Get a list of all matching tags
    tag_pattern = tag_name.format(**glob_fields)
    tags = repo.run_cmd(['tag', '-l', tag_pattern]).splitlines()
    log.debug("Found %d tags matching pattern '%s'", len(tags), tag_pattern)

    # Parse undefined fields from tag names
    re_fields = {f: r'(?P<{}>[\w\-.()]+)'.format(f) for f in field_names}
    re_fields['branch'] = r'(?P<branch>[\w\-.()/]+)'
    re_fields['commit'] = '(?P<commit>[0-9a-f]{7,40})'
    re_fields['commit_number'] = '(?P<commit_number>[0-9]{1,7})'
    re_fields['tag_number'] = '(?P<tag_number>[0-9]{1,5})'
    # Escape parentheses in fields in order to not mess up the regexp
    re_fields.update({k: v.replace('(', r'\(').replace(')', r'\)')
                      for k, v in kwargs.items()})
    tag_re = re.compile(tag_name.format(**re_fields))

    # Parse fields from tags
    revs = []
    for tag in tags:
        m = tag_re.match(tag)
        if m is None:
            # The wildcard pattern used for listing the tags is looser than
            # the regexp, so a listed tag is not guaranteed to be parseable
            log.debug("Skipping tag '%s', unable to parse fields from it", tag)
            continue
        groups = m.groupdict()
        revs.append([groups[f] for f in undef_fields] + [tag])

    # Return field names and a sorted list of revs
    return undef_fields, sorted(revs)

def list_test_revs(repo, tag_name, verbosity, **kwargs):
    """Print a table of all tested revisions

    Arguments:
        repo -- repository object, passed on to get_test_runs()
        tag_name -- tag name pattern, passed on to get_test_runs()
        verbosity -- with values below 2 the commit and commit number are not
                     printed, but the number of distinct commits is shown in
                     a COMMITS column instead
        kwargs -- fixed field values, passed on to get_test_runs()

    The 'tag_number' field is never printed. Consecutive test runs having
    identical values in all printed fields are merged into one table row
    with the number of test runs shown in the TEST RUNS column.
    """
    fields, revs = get_test_runs(repo, tag_name, **kwargs)
    ignore_fields = ['tag_number']
    if verbosity < 2:
        extra_fields = ['COMMITS', 'TEST RUNS']
        ignore_fields.extend(['commit_number', 'commit'])
    else:
        extra_fields = ['TEST RUNS']

    # Indexes (into 'fields' and into each rev) of the fields to print
    print_fields = [i for i, f in enumerate(fields) if f not in ignore_fields]
    num_cols = len(print_fields)

    # Header row
    rows = [[fields[i].upper() for i in print_fields] + extra_fields]

    # None makes the first rev always start a new row, even if there are no
    # fields to print
    prev = None
    prev_commit = None
    commit_cnt = 0
    test_run_cnt = 0
    new_row = None
    commit_field = fields.index('commit')
    for rev in revs:
        # Only use fields that we want to print
        cols = [rev[i] for i in print_fields]

        if cols != prev:
            commit_cnt = 1
            test_run_cnt = 1
            new_row = [''] * (num_cols + len(extra_fields))

            # Leave the leading columns that are identical to the previous
            # row empty. Note that 'cols' must be indexed by column position,
            # not by the field indexes stored in 'print_fields'.
            first_diff = 0
            if prev is not None:
                while cols[first_diff] == prev[first_diff]:
                    first_diff += 1
            new_row[first_diff:num_cols] = cols[first_diff:]
            rows.append(new_row)
        else:
            if rev[commit_field] != prev_commit:
                commit_cnt += 1
            test_run_cnt += 1

        # Update the counters of the row currently being accumulated
        if verbosity < 2:
            new_row[-2] = commit_cnt
        new_row[-1] = test_run_cnt
        prev = cols
        prev_commit = rev[commit_field]

    print_table(rows)

def get_test_revs(repo, tag_name, **kwargs):
    """Get list of all tested revisions

    Arguments are passed on to get_test_runs(). The tag name pattern must
    contain the 'commit' and 'commit_number' fields and they must not be
    fixed via kwargs, otherwise ValueError is raised.

    Returns a list of TestedRev items, one per commit, sorted by commit
    number in ascending numerical order. The 'tags' of each item list all
    the test runs found for that commit.
    """
    fields, runs = get_test_runs(repo, tag_name, **kwargs)

    revs = {}
    commit_i = fields.index('commit')
    commit_num_i = fields.index('commit_number')
    for run in runs:
        commit = run[commit_i]
        commit_num = run[commit_num_i]
        tag = run[-1]
        if commit not in revs:
            revs[commit] = TestedRev(commit, commit_num, [tag])
        else:
            assert commit_num == revs[commit].commit_number, "Commit numbers do not match"
            revs[commit].tags.append(tag)

    # Return in sorted table. Commit numbers are strings parsed from the tag
    # names so they need to be compared as integers, plain string comparison
    # would e.g. put '9999' after '10000'.
    revs = sorted(revs.values(), key=lambda rev: int(rev.commit_number))
    log.debug("Found %d tested revisions:\n    %s", len(revs),
              "\n    ".join(['{} ({})'.format(rev.commit_number, rev.commit) for rev in revs]))
    return revs

def rev_find(revs, attr, val):
    """Find the index of the first item in 'revs' whose 'attr' equals 'val'

    Raises ValueError if none of the items match.
    """
    matching = (pos for pos, candidate in enumerate(revs)
                if getattr(candidate, attr) == val)
    found = next(matching, None)
    if found is None:
        raise ValueError("Unable to find '{}' value '{}'".format(attr, val))
    return found

def is_xml_format(repo, commit):
    """Tell whether the results of a commit are in xml (True) or json (False)

    The format is judged to be xml if repo.rev_parse() returns a true value
    for '<commit>:results.xml', otherwise json is assumed.
    """
    if not repo.rev_parse(commit + ':results.xml'):
        log.debug("No xml report in %s, assuming json formatted results", commit)
        return False
    log.debug("Detected report in xml format in %s", commit)
    return True

def read_results(repo, tags, xml=True):
    """Read the metadata and result files of the given tags from the repo

    Arguments:
        repo -- repository object; its run_cmd() is used to run 'show'
        tags -- list of tags to read, must not be empty
        xml -- read metadata.xml/results.xml if true, otherwise
               metadata.json/results.json

    Returns a tuple (metadata, results) of two lists, both having one item
    per tag, in the order of 'tags'.
    """

    def parse_xml_stream(data):
        """Split a stream of concatenated XML documents and parse each"""
        docs = []
        pending = []
        for line in data.splitlines():
            # An XML declaration starts a new document, except when nothing
            # has been collected yet
            if pending and line.startswith('<?xml version='):
                docs.append(ET.fromstring(''.join(pending)))
                pending = [line]
            elif line:
                # Lines are concatenated without the line terminators; empty
                # lines contribute nothing
                pending.append(line)
        docs.append(ET.fromstring(''.join(pending)))
        return docs

    def parse_json_stream(data):
        """Split a stream of concatenated JSON objects and parse each"""
        docs = []
        pending = []
        for line in data.splitlines():
            if line == '}{':
                # End of one object and start of the next on the same line
                pending.append('}')
                docs.append(json.loads(''.join(pending), object_pairs_hook=OrderedDict))
                pending = ['{']
            else:
                pending.append(line)
        docs.append(json.loads(''.join(pending), object_pairs_hook=OrderedDict))
        return docs

    num_revs = len(tags)

    # Optimize by reading all data with one git command: all the metadata
    # files come first in the output, followed by all the result files
    log.debug("Loading raw result data from %d tags, %s...", num_revs, tags[0])
    if not xml:
        git_objs = [tag + ':metadata.json' for tag in tags]
        git_objs += [tag + ':results.json' for tag in tags]
        parsed = parse_json_stream(repo.run_cmd(['show'] + git_objs + ['--']))
        return parsed[:num_revs], parsed[num_revs:]

    git_objs = [tag + ':metadata.xml' for tag in tags]
    git_objs += [tag + ':results.xml' for tag in tags]
    parsed = parse_xml_stream(repo.run_cmd(['show'] + git_objs + ['--']))
    metadata = [metadata_xml_to_json(elem) for elem in parsed[:num_revs]]
    results = [results_xml_to_json(elem) for elem in parsed[num_revs:]]
    return (metadata, results)


def get_data_item(data, key):
    """Look up an item from nested containers with a dot-separated key

    Each component of 'key' is used, in turn, to index the value found with
    the previous component, starting from 'data'.
    """
    node = data
    for component in key.split('.'):
        node = node[component]
    return node


def metadata_diff(metadata_l, metadata_r):
    """Prepare a metadata diff for printing

    Returns an OrderedDict with one item per reported metadata field. Each
    item is a dict with the human readable 'title' of the field, its value
    in 'metadata_l' as 'value_old' and its value in 'metadata_r' as 'value'.
    A value missing from the metadata is replaced with '(N/A)'.
    """
    def _lookup(metadata, key):
        """Get an item from metadata, or a placeholder if it is missing"""
        try:
            return get_data_item(metadata, key)
        except KeyError:
            return '(N/A)'

    # (title, key in the returned dict, dotted key in the metadata)
    items = (
        ('Hostname', 'hostname', 'hostname'),
        ('Branch', 'branch', 'layers.meta.branch'),
        ('Commit number', 'commit_num', 'layers.meta.commit_count'),
        ('Commit', 'commit', 'layers.meta.commit'),
        ('Number of test runs', 'testrun_count', 'testrun_count'),
    )

    diff = OrderedDict()
    for title, key, key_json in items:
        old_value = _lookup(metadata_l, key_json)
        new_value = _lookup(metadata_r, key_json)
        diff[key] = {'title': title,
                     'value_old': old_value,
                     'value': new_value}
    return diff


def print_diff_report(metadata_l, data_l, metadata_r, data_r):
    """Print a plain text comparison of two sets of test results

    Arguments:
        metadata_l, data_l -- metadata and test data of the 'left' revision,
                              i.e. the one being compared against
        metadata_r, data_r -- metadata and test data of the 'right' revision

    Tests and measurements only present in the 'right' data are marked with
    '+' and those only present in the 'left' data with '-'.
    """
    # First, print general metadata
    print("\nTEST METADATA:\n==============")
    meta_fmt = ['{:{wid}} ', '{:<{wid}}   ', '{:<{wid}}']
    meta_rows = [['', 'CURRENT COMMIT', 'COMPARING WITH']]
    for key, val in metadata_diff(metadata_l, metadata_r).items():
        current, other = val['value'], val['value_old']
        if key == 'commit':
            # Shorten commit hashes
            current, other = current[:20], other[:20]
        meta_rows.append([val['title'] + ':', current, other])
    print_table(meta_rows, meta_fmt)

    # Print test results
    print("\nTEST RESULTS:\n=============")

    tests_l = data_l['tests']
    tests_r = data_r['tests']
    # Tests of the 'left' set, followed by those only present in 'right' set
    test_names = list(tests_l) + [t for t in tests_r if t not in tests_l]

    # Prepare data to be printed
    row_fmt = ['{:8}', '{:{wid}}', '{:{wid}}', '  {:>{wid}}', ' {:{wid}} ', '{:{wid}}',
               '  {:>{wid}}', '  {:>{wid}}']
    empty_row_len = len(row_fmt)
    rows = []
    for test in test_names:
        test_l = tests_l.get(test)
        test_r = tests_r.get(test)
        if test_l is None:
            pref = '+'
        elif test_r is None:
            pref = '-'
        else:
            pref = ' '
        descr = (test_l or test_r)['description']
        rows.append(["{} {}: {}".format(pref, test, descr)])

        # Generate the list of measurements
        meas_l = test_l['measurements'] if test_l else {}
        meas_r = test_r['measurements'] if test_r else {}
        meas_names = list(meas_l) + [m for m in meas_r if m not in meas_l]

        for meas in meas_names:
            # A missing measurement is passed as None to measurement_stats()
            stats = dict(measurement_stats(meas_l.get(meas), 'l.'))
            stats.update(measurement_stats(meas_r.get(meas), 'r.'))
            if meas not in meas_r:
                m_pref = '-'
            elif meas not in meas_l:
                m_pref = '+'
            else:
                m_pref = ' '

            mean_l = stats['l.mean']
            mean_r = stats['r.mean']
            absdiff = stats['val_cls'](mean_r - mean_l)
            reldiff = "{:+.1f} %".format(absdiff * 100 / mean_l)
            # Explicit plus sign for increased values
            sign = '+' if mean_r > mean_l else ''
            rows.append(['', m_pref, stats['name'] + ' ' + stats['quantity'],
                         str(mean_l), '->', str(mean_r),
                         sign + str(absdiff), reldiff])
        # Empty row as a separator between tests
        rows.append([''] * empty_row_len)

    print_table(rows, row_fmt)

    print()


def print_html_report(data, id_comp):
    """Print report in html format

    Arguments:
        data -- list of (metadata, test data) tuples, one per revision; the
                last item is the revision that is being reported
        id_comp -- index (in 'data') of the revision to compare against
    """
    # Handle metadata: revision compared against vs. the reported revision
    metadata = metadata_diff(data[id_comp][0], data[-1][0])

    # Generate list of tests, based on the tests of the reported revision
    tests = []
    for test, test_r in data[-1][1]['tests'].items():
        new_test = {'name': test_r['name'],
                    'description': test_r['description'],
                    'status': test_r['status'],
                    'measurements': [],
                    'err_type': test_r.get('err_type'),
                   }
        # Limit length of err output shown
        if 'message' in test_r:
            lines = test_r['message'].splitlines()
            if len(lines) > 20:
                new_test['message'] = '...\n' + '\n'.join(lines[-20:])
            else:
                new_test['message'] = test_r['message']

        # Generate the list of measurements
        for meas, meas_r in test_r['measurements'].items():
            meas_type = 'time' if meas_r['type'] == 'sysres' else 'size'
            new_meas = {'name': meas_r['name'],
                        'legend': meas_r['legend'],
                        'description': meas_r['name'] + ' ' + meas_type,
                       }
            samples = []

            # Run through all revisions in our data
            for meta, test_data in data:
                if (test not in test_data['tests'] or
                        meas not in test_data['tests'][test]['measurements']):
                    # Placeholder for revisions lacking this measurement
                    samples.append(measurement_stats(None))
                    continue
                test_i = test_data['tests'][test]
                meas_i = test_i['measurements'][meas]
                commit_num = get_data_item(meta, 'layers.meta.commit_count')
                samples.append(measurement_stats(meas_i))
                samples[-1]['commit_num'] = commit_num

            # The last sample always exists as the tests and measurements
            # are taken from the last item of 'data'
            absdiff = samples[-1]['val_cls'](samples[-1]['mean'] - samples[id_comp]['mean'])
            new_meas['absdiff'] = absdiff
            new_meas['absdiff_str'] = str(absdiff) if absdiff < 0 else '+' + str(absdiff)
            new_meas['reldiff'] = "{:+.1f} %".format(absdiff * 100 / samples[id_comp]['mean'])
            new_meas['samples'] = samples
            new_meas['value'] = samples[-1]
            new_meas['value_type'] = samples[-1]['val_cls']

            new_test['measurements'].append(new_meas)
        tests.append(new_test)

    # Chart options: horizontal axis spans the commit numbers of the data
    chart_opts = {'haxis': {'min': get_data_item(data[0][0], 'layers.meta.commit_count'),
                            'max': get_data_item(data[-1][0], 'layers.meta.commit_count')}
                 }

    print(html.template.render(metadata=metadata, test_data=tests, chart_opts=chart_opts))


def dump_buildstats(repo, outdir, notes_ref, revs):
    """Dump buildstats of test results

    Arguments:
        repo -- repository object; its rev_parse() and run_cmd() are used
        outdir -- directory under which the buildstats are written
        notes_ref -- name of the notes ref (without the 'refs/notes/' prefix)
                     that the buildstats are read from
        revs -- list of TestedRev items whose buildstats to dump

    The buildstats of each tag are written into
    <outdir>/<measurement>/<tag base>/<run id>.json where the run id is the
    last component of the tag name and the tag base is the rest of the tag
    name with slashes replaced by underscores.
    """
    full_ref = 'refs/notes/' + notes_ref
    if not repo.rev_parse(full_ref):
        log.error("No buildstats found, please try running "
                  "'git fetch origin %s:%s' to fetch them from the remote",
                  full_ref, full_ref)
        return

    missing = False
    log.info("Writing out buildstats from 'refs/notes/%s' into '%s'",
             notes_ref, outdir)
    for rev in revs:
        log.debug('Dumping buildstats for %s (%s)', rev.commit_number,
                  rev.commit)
        for tag in rev.tags:
            log.debug('    %s', tag)
            try:
                notes = repo.run_cmd(['notes', '--ref', notes_ref,
                                      'show', tag + '^0'])
            except GitError:
                log.warning("Buildstats not found for %s", tag)
                missing = True
                # Nothing to write for this tag. Without this we would either
                # crash or write the buildstats of the previous tag again.
                continue
            bs_all = json.loads(notes)
            for measurement, buildstats in bs_all.items():
                tag_base, run_id = tag.rsplit('/', 1)
                tag_base = tag_base.replace('/', '_')
                bs_dir = os.path.join(outdir, measurement, tag_base)
                os.makedirs(bs_dir, exist_ok=True)
                with open(os.path.join(bs_dir, run_id + '.json'), 'w') as f:
                    json.dump(buildstats, f, indent=2)
    if missing:
        log.info("Buildstats were missing for some test runs, please "
                 "run 'git fetch origin %s:%s' and try again",
                 full_ref, full_ref)


def auto_args(repo, args):
    """Guess the hostname and branch arguments from the repository

    The body of the latest commit (of all branches and remotes) is searched
    for 'hostname: <value>' and 'branch: <value>' lines and the values found
    are stored in 'args', which is modified in place.
    """
    log.debug("Guessing arguments from the latest commit")
    # Get the body of the latest commit in the repo
    body = repo.run_cmd(['log', '-1', '--branches', '--remotes', '--format=%b'])
    for line in body.splitlines():
        key, colon, rest = line.partition(':')
        if not colon:
            # Not a 'key: value' line
            continue

        val = rest.strip()
        if key == 'hostname':
            log.debug("Using hostname %s", val)
            args.hostname = val
        elif key == 'branch':
            log.debug("Using branch %s", val)
            args.branch = val


def parse_args(argv):
    """Parse command line arguments

    'argv' is the list of arguments to parse. Returns the namespace object
    produced by argparse.
    """
    # Options of the parser itself, as (flags, add_argument() options) pairs
    general_opts = (
        (('--debug', '-d'),
         dict(action='store_true', help="Verbose logging")),
        (('--repo', '-r'),
         dict(required=True, help="Results repository (local git clone)")),
        (('--list', '-l'),
         dict(action='count', help="List available test runs")),
        (('--html',),
         dict(action='store_true', help="Generate report in html format")),
    )
    # Options for selecting the test results to look at
    revision_opts = (
        (('--tag-name', '-t'),
         dict(default='{hostname}/{branch}/{machine}/{commit_number}-g{commit}/{tag_number}',
              help="Tag name (pattern) for finding results")),
        (('--hostname', '-H'), dict()),
        (('--branch', '-B'), dict(default='master')),
        (('--machine',), dict(default='qemux86')),
        (('--history-length',),
         dict(default=25, type=int,
              help="Number of tested revisions to plot in html report")),
        (('--commit',), dict(help="Revision to search for")),
        (('--commit-number',),
         dict(help="Revision number to search for, redundant if "
                   "--commit is specified")),
        (('--commit2',), dict(help="Revision to compare with")),
        (('--commit-number2',),
         dict(help="Revision number to compare with, redundant if "
                   "--commit2 is specified")),
    )

    parser = argparse.ArgumentParser(
        description="\nExamine build performance test results from a Git repository",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    for flags, options in general_opts:
        parser.add_argument(*flags, **options)

    rev_group = parser.add_argument_group('Tag and revision')
    for flags, options in revision_opts:
        rev_group.add_argument(*flags, **options)

    # The value of the option is optional, '.' is used if it is left out
    parser.add_argument('--dump-buildstats', nargs='?', const='.',
                        help="Dump buildstats of the tests")

    return parser.parse_args(argv)


def main(argv=None):
    """Script entry point

    'argv' is the list of command line arguments, passed on to parse_args().
    Returns the exit code of the script: 0 on success and 1 on error.
    """
    args = parse_args(argv)
    if args.debug:
        log.setLevel(logging.DEBUG)

    repo = GitRepo(args.repo)

    if args.list:
        list_test_revs(repo, args.tag_name, args.list)
        return 0

    # Determine hostname which to use
    if not args.hostname:
        auto_args(repo, args)
    if not args.hostname:
        log.error("Unable to determine the hostname, use --hostname to "
                  "specify it")
        return 1

    revs = get_test_revs(repo, args.tag_name, hostname=args.hostname,
                         branch=args.branch, machine=args.machine)
    if len(revs) < 2:
        log.error("%d tester revisions found, unable to generate report",
                  len(revs))
        return 1

    # Pick revisions
    try:
        if args.commit:
            if args.commit_number:
                log.warning("Ignoring --commit-number as --commit was specified")
            index1 = rev_find(revs, 'commit', args.commit)
        elif args.commit_number:
            index1 = rev_find(revs, 'commit_number', args.commit_number)
        else:
            # Default to the latest tested revision
            index1 = len(revs) - 1

        if args.commit2:
            if args.commit_number2:
                log.warning("Ignoring --commit-number2 as --commit2 was specified")
            index2 = rev_find(revs, 'commit', args.commit2)
        elif args.commit_number2:
            index2 = rev_find(revs, 'commit_number', args.commit_number2)
        else:
            # Default to the revision preceding the first one
            if index1 > 0:
                index2 = index1 - 1
            else:
                log.error("Unable to determine the other commit, use "
                          "--commit2 or --commit-number2 to specify it")
                return 1
    except ValueError as err:
        # Revision given by the user was not found among the tested ones
        log.error("%s", err)
        return 1

    index_l = min(index1, index2)
    index_r = max(index1, index2)

    rev_l = revs[index_l]
    rev_r = revs[index_r]
    log.debug("Using 'left' revision %s (%s), %s test runs:\n    %s",
              rev_l.commit_number, rev_l.commit, len(rev_l.tags),
              '\n    '.join(rev_l.tags))
    log.debug("Using 'right' revision %s (%s), %s test runs:\n    %s",
              rev_r.commit_number, rev_r.commit, len(rev_r.tags),
              '\n    '.join(rev_r.tags))

    # Check report format used in the repo (assume all reports in the same fmt)
    xml = is_xml_format(repo, revs[index_r].tags[-1])

    if args.html:
        # Make sure that the 'left' revision is included in the range, even
        # if it is further back in history than --history-length
        index_0 = max(0, min(index_l, index_r - args.history_length))
        rev_range = range(index_0, index_r + 1)
    else:
        # We do not need range of commits for text report (no graphs)
        index_0 = index_l
        rev_range = (index_l, index_r)

    # Read raw data
    log.debug("Reading %d revisions, starting from %s (%s)",
              len(rev_range), revs[index_0].commit_number, revs[index_0].commit)
    raw_data = [read_results(repo, revs[i].tags, xml) for i in rev_range]

    data = []
    for raw_m, raw_d in raw_data:
        data.append((aggregate_metadata(raw_m), aggregate_data(raw_d)))

    # Re-map list indexes to the new table starting from index 0. The 'right'
    # revision is always the last item of the table, also in the text report
    # where only two revisions are loaded, regardless of how far apart they
    # are in the list of tested revisions.
    index_r = len(data) - 1
    index_l = index_l - index_0

    # Print report
    if not args.html:
        print_diff_report(data[index_l][0], data[index_l][1],
                          data[index_r][0], data[index_r][1])
    else:
        print_html_report(data, index_l)

    # Dump buildstats
    # NOTE(review): the optional value of --dump-buildstats is not used for
    # anything, the output directory is fixed - confirm if that is intended
    if args.dump_buildstats:
        notes_ref = 'buildstats/{}/{}/{}'.format(args.hostname, args.branch,
                                                 args.machine)
        dump_buildstats(repo, 'oe-build-perf-buildstats', notes_ref,
                        [rev_l, rev_r])

    return 0

# When run as a script, use the return value of main() as the exit code
if __name__ == "__main__":
    sys.exit(main())