"""
telemetry.py
Created on: July 26, 2017
Proprietary and confidential information of Oxford Nanopore Technologies, Ltd
All rights reserved; (c)2017: Oxford Nanopore Technologies, Limited
"""
import ast
from collections import defaultdict
import copy
import json
import logging
from math import floor
import os
import uuid

from . import __version__ as ALBACORE_VERSION
from .telemetry_utils import Histogram

# For histograms
# The bin name maps to (min_bin_value, max_bin_value, num_bins, name); the
# tuple is unpacked directly into the Histogram constructor
# (``Histogram(*BIN_CONFIG[entry])``). The fourth element is presumably the
# label of the binned quantity -- confirm against telemetry_utils.Histogram.
# Note this is quite simplistic -- you're going to want to make sure that
# num_bins divides (max_bin_value - min_bin_value) in a nice way.
# NOTE(review): 'seq_len_ratio_bases_dist' and 'open_pore_length_seconds_dist'
# pass a fractional third value (0.1, 0.05), which reads like a bin *width*
# rather than a bin count -- confirm how Histogram interprets it.
BIN_CONFIG = {'qscore_dist_temp': (0, 100, 200, 'mean_qscore'),
              'speed_events_per_second_dist_temp': (0, 1000, 1000, 'speed'),
              'speed_bases_per_second_dist_temp': (0, 2000, 1000, 'speed'),
              'seq_len_events_dist_temp': (0, 1000000, 1000, 'length'),
              'seq_len_bases_dist_temp': (0, 1000000, 1000, 'length'),
              # 1dsq
              'qscore_dist': (0, 100, 200, 'mean_qscore'),
              'seq_len_ratio_bases_dist': (0, 5, 0.1, 'ratio'),
              'open_pore_length_seconds_dist': (0, 5, 0.05, 'length')
              }

# Keys of the basecalling sub-collections stored inside each telemetry entry.
BASECALL_1D_COLLECTION = 'basecall_1d'
BASECALL_1DSQ_COLLECTION = 'basecall_1d2'

# Get rid of a few fields which we don't need from the options.
# These are deleted from the albacore options before they are stored in
# telemetry (they may contain local paths).
ALBACORE_OPTS_FIELDS_TO_REMOVE = ['input', 'save_path', 'data_path',
                                  'default_path', 'read_path']


class TelemetryAggregator(object):
    """ Aggregates per-read summary information into a single set of albacore
    summaries.

    This produces a high-level summary of the results of a particular call to
    albacore. Each set of telemetry output is broken down into an overall
    summary and an "acquisition-segment" summary which includes only data from
    reads produced within that segment (e.g. if the segment length is an hour
    then telemetry will contain summaries for only data produced during a
    particular hour).

    Entries live in ``summary_collection``, keyed by the bare run_id for the
    cumulative (full-run) aggregation and by a ``(run_id, segment_number)``
    tuple for each acquisition-segment aggregation.

    """

    # Names of the per-read histograms kept in the 1d basecalling collection.
    # Each name is also a key into the module-level BIN_CONFIG.
    _BASECALL_1D_HISTOGRAMS = ('qscore_dist_temp',
                               'speed_events_per_second_dist_temp',
                               'speed_bases_per_second_dist_temp',
                               'seq_len_events_dist_temp',
                               'seq_len_bases_dist_temp')

    def __init__(self, segment_duration=60,
                 software_name='albacore-basecalling', logger_name=__name__,
                 albacore_opts=None, include_1d_basecalling=True,
                 include_1dsq_basecalling=False, include_calibration=False,
                 include_alignment=False, include_barcoding=False,
                 analysis_id=None):
        """ Constructs the object.

        :param segment_duration: The duration of the acquisition-segment, in
            minutes, to use. Must be positive.
        :param software_name: The name to include in high-level summary packets.
        :param logger_name: Name of the logger to use.
        :param albacore_opts: Albacore options object. It is deep-copied and
            converted with ``vars()``, so it must expose ``__dict__``
            (presumably an argparse Namespace -- confirm against callers).
        :param include_1d_basecalling: Whether to aggregate 1d basecalling info.
        :param include_1dsq_basecalling: Whether to aggregate 1dsq basecalling
            info.
        :param include_calibration: Whether to aggregate calibration strand
            info.
        :param include_alignment: Whether to aggregate alignment info.
        :param include_barcoding: Whether to aggregate barcoding info.
        :param analysis_id: Optional albacore-acquisition id to assign to
            this aggregator (when multiple telemetries should share the same
            one).
        :raises: ValueError if segment_duration is not positive.

        """
        # Validate explicitly: an ``assert`` is stripped under ``python -O``,
        # and a zero duration would divide by zero in _get_acquisition_segment.
        if not segment_duration > 0:
            raise ValueError('segment_duration must be positive, got '
                             '{}'.format(segment_duration))
        self.segment_duration = segment_duration
        self.summary_collection = {}
        self.software = {'name': software_name,
                         'version': ALBACORE_VERSION}
        if analysis_id is not None:
            self.analysis_id = analysis_id
        else:
            self.analysis_id = generate_analysis_id()
        # Describe which basecalling analyses are covered, producing one of
        # '1d', '1d_and_1dsq', '1dsq' or 'non', followed by '_basecalling'.
        basecalling_analysis = 'non'
        if include_1d_basecalling:
            basecalling_analysis = '1d'
            if include_1dsq_basecalling:
                basecalling_analysis += '_and_1dsq'
        elif include_1dsq_basecalling:
            basecalling_analysis = '1dsq'
        self.software['analysis'] = '{}_basecalling'.format(basecalling_analysis)
        self.logger_name = logger_name
        self.albacore_opts = {}
        if albacore_opts is not None:
            self.albacore_opts = vars(copy.deepcopy(albacore_opts))
            # Remove bits with local paths in them
            for field in ALBACORE_OPTS_FIELDS_TO_REMOVE:
                self.albacore_opts.pop(field, None)
            # We'd like to keep the config name, but we'll remove extraneous
            # bits of the path
            if 'config' in self.albacore_opts:
                self.albacore_opts['config'] = \
                    os.path.basename(self.albacore_opts['config'])
        self.include_1d_basecalling = include_1d_basecalling
        self.include_1dsq_basecalling = include_1dsq_basecalling
        self.include_calibration = include_calibration
        # Misspelt attribute name kept as an alias for backwards compatibility.
        self.include_calbration = include_calibration
        self.include_alignment = include_alignment
        self.include_barcoding = include_barcoding

    def add_summary_data(self, summary_line, tracking_id_data, context_tags):
        """ Accepts a single line of summary information and updates the
        various telemetry collections.

        Both the cumulative (full-run) entry and the entry for the read's
        acquisition segment are updated.

        :param summary_line: A dictionary of summary data, such as would be
            written into a seqchem_summary.txt file.
        :param tracking_id_data: Tracking id dictionary for summary_line.
        :param context_tags: Context tag dictionary for summary_line.
        :returns: True if the line was added successfully.
        :rtype: boolean
        :raises: KeyError if one of the required summary values is not present.

        """
        logger = logging.getLogger(self.logger_name)
        try:
            run_id = tracking_id_data['run_id']
            start_time_seconds = summary_line['start_time']
        except KeyError as e:
            logger.error('Failed to load key: {} from summary data.'.format(e))
            raise
        acquisition_segment = _get_acquisition_segment(start_time_seconds,
                                                       self.segment_duration)
        # The bare run_id keys the cumulative entry; the tuple key keys the
        # entry for this read's acquisition segment.
        for key in [run_id,
                    _generate_segment_summary_key(run_id, acquisition_segment)]:
            self.summary_collection.setdefault(key, {})
            try:
                self._add_common_data(summary_line, tracking_id_data,
                                      context_tags, acquisition_segment, key)
                if self.include_1d_basecalling:
                    self._add_segmentation_data(key)
                    self._add_basecall_1d_data(summary_line, key)
                if self.include_1dsq_basecalling:
                    self._add_basecall_1dsq_data(summary_line, key)
            except KeyError as e:
                logger.error('Failed to load key: {} into telemetry '
                             'data.'.format(e))
                raise
        return True

    def add_1dsq_candidate_data(self, read1_summary_line, read2_summary_line,
                                run_id):
        """ Add data for 1dsq candidates.

        This is slightly different to adding a normal summary line, as we're
        missing the tracking_id and context_tags collections (because we're not
        reading in fast5 files -- we're parsing the summary file from a 1d
        run). This doesn't interfere with normal use of add_summary_data.

        :param read1_summary_line: 1d summary line for the first candidate read.
        :param read2_summary_line: 1d summary line for the second candidate
            read.
        :param run_id: The run_id for the two reads.
        :returns: True if the entry was added successfully.
        :rtype: boolean
        :raises: KeyError if one of the required summary values is not present.

        """
        logger = logging.getLogger(self.logger_name)
        try:
            start_time_seconds = read1_summary_line['start_time']
        except KeyError as e:
            logger.error('Failed to load key: {} from summary data.'.format(e))
            raise
        acquisition_segment = _get_acquisition_segment(start_time_seconds,
                                                       self.segment_duration)
        for key in [run_id,
                    _generate_segment_summary_key(run_id, acquisition_segment)]:
            self.summary_collection.setdefault(key, {})
            try:
                self._add_basecall_1dsq_candidate_data(read1_summary_line,
                                                       read2_summary_line, key)
            except KeyError as e:
                logger.error('Failed to load key: {} into telemetry '
                             'data.'.format(e))
                raise
        return True

    def add_json(self, json_file):
        """ Read in an existing json telemetry file and add the data in it to
        this aggregator.

        Entries whose key is not yet present are stored as they are; others are
        merged into the existing entry. A merge that fails with a KeyError is
        logged and skipped, leaving the existing entry unchanged.

        :param json_file: filename of the json file to add
        :raises: ValueError if the file does not contain valid json; KeyError
            if an entry lacks tracking_id/run_id, aggregation or (for
            non-cumulative entries) segment_number.

        """
        def is_histogram(thing):
            # A non-empty list whose first element is a two-key dict containing
            # 'count' is treated as a serialised Histogram.
            return (type(thing) is list
                    and len(thing) > 0
                    and type(thing[0]) is dict
                    and len(thing[0]) == 2
                    and 'count' in thing[0])

        def decode_histograms(obj):
            # Recursively replace serialised histograms with Histogram objects.
            if is_histogram(obj):
                return Histogram.from_dict(obj)
            elif isinstance(obj, list):
                return [decode_histograms(x) for x in obj]
            elif isinstance(obj, dict):
                return {k: decode_histograms(v) for k, v in obj.items()}
            return obj

        logger = logging.getLogger(self.logger_name)
        with open(json_file, 'r') as input_file:
            contents = input_file.read()
        try:
            # json.JSONDecodeError is a subclass of ValueError.
            data_list = json.loads(contents)
        except ValueError as e:
            logger.error('Failed to import file {} due to json decoding '
                         'error: {}'.format(json_file, e))
            raise
        data_list = decode_histograms(data_list)

        # We should have a list of dicts now: one per run and one per
        # acquisition segment.
        for data in data_list:
            run_id = data['tracking_id']['run_id']
            aggregation = data['aggregation']
            if aggregation != 'cumulative':
                key = _generate_segment_summary_key(run_id,
                                                    data['segment_number'])
            else:
                key = run_id
            if key not in self.summary_collection:
                self.summary_collection[key] = data
                continue
            # Merge into a copy so a failed merge cannot leave the existing
            # entry half-updated.
            merged = copy.deepcopy(self.summary_collection[key])
            try:
                _merge_entries(data, merged, aggregation)
            except KeyError as e:
                logger.error('Failed to add data from key {}:{} into '
                             'existing collection.'.format(key, e))
                continue
            self.summary_collection[key] = merged

    def to_json(self):
        """ Returns the current set of telemetry in json format, suitable for
        writing to a file.

        The output is a (string) list of json entries, where each entry
        corresponds either to:

            * A full-run aggregation, or
            * An acquisition-segment aggregation.

        :returns: json representation of the telemetry
        :rtype: string

        """
        return '[' + ','.join(self.json_entries()) + ']'

    def json_entries(self):
        """ json generator over the individual telemetry entries.

        This is suitable for use when pinging each telemetry entry
        individually. As with :py:meth:`to_json` each entry corresponds either
        to:

            * A full-run aggregation, or
            * An acquisition-segment aggregation.

        Note that a unique-per-entry field, msg_id, is added to the tracking_id
        section as part of this process. Entries without a tracking_id (e.g.
        those holding only 1dsq candidate data) are skipped with a warning.

        :returns: Individual telemetry entries in json format
        :rtype: (Generator of) string

        """
        class HistogramEncoder(json.JSONEncoder):
            def default(self, obj):
                if isinstance(obj, Histogram):
                    return obj.to_list_of_dicts()
                return json.JSONEncoder.default(self, obj)

        logger = logging.getLogger(self.logger_name)
        for key, entry in self.summary_collection.items():
            # If we've only evaluated 1dsq candidates then we won't have any
            # tracking_id entries, and therefore we don't want to ping them.
            if 'tracking_id' not in entry:
                logger.warning('No tracking_id found in entry with '
                               'key {}'.format(key))
                continue
            entry['tracking_id']['msg_id'] = str(uuid.uuid4())
            yield json.dumps(entry, cls=HistogramEncoder)

    def _add_common_data(self, summary_line, tracking_id_data, context_tags,
                         acquisition_segment, summary_collection_key):
        """ Update the fields shared by every telemetry entry, initialising
        the entry on first use.

        summary_collection_key will either be a tuple of
        (run_id, acquisition_segment) for acquisition-segment pings, or just the
        run_id for cumulative-pings. The update is made on a deep copy that is
        only written back at the end, so a KeyError leaves the stored entry
        unchanged.

        :raises: KeyError if summary_line lacks a required field.

        """
        # Copy the collection out so we can avoid changes on error
        collection = copy.deepcopy(
            self.summary_collection[summary_collection_key])
        is_segment_entry = isinstance(summary_collection_key, tuple)
        # We'll use tracking_id as a general indicator of whether or not we need
        # to initialize the collection
        if 'tracking_id' not in collection:
            # Store copies: json_entries() writes a msg_id into the stored
            # tracking_id, which must not leak into the caller's dict.
            collection['tracking_id'] = copy.deepcopy(tracking_id_data)
            collection['software'] = self.software
            collection['context_tags'] = copy.deepcopy(context_tags)
            collection['segment_type'] = 'albacore-acquisition'
            # For the cumulative entry this is the segment of the first read
            # seen for the run.
            collection['segment_number'] = acquisition_segment
            if is_segment_entry:
                collection['segment_duration'] = self.segment_duration
                collection['aggregation'] = 'segment'
            else:
                collection['segment_duration'] = \
                    self.segment_duration * acquisition_segment
                collection['aggregation'] = 'cumulative'
            collection['latest_run_time'] = summary_line['start_time']
            collection['run_id'] = summary_line['run_id']
            collection['read_count'] = 0
            collection['albacore_opts'] = self.albacore_opts
            collection['albacore_analysis_id'] = self.analysis_id
            if self.include_1d_basecalling:
                collection['reads_per_channel_dist'] = Histogram(1, 3001, 3000,
                                                                 'channel')
                collection['channel_count'] = 0
                collection['levels_sums'] = {'open_pore_level_sum': 0,
                                             'count': 0}
        # For cumulative aggregation segment_duration extends to the end of the
        # latest acquisition segment seen so far
        if not is_segment_entry:
            collection['segment_duration'] = max(
                collection['segment_duration'],
                acquisition_segment * self.segment_duration)

        # Update the collection for this particular summary_line
        collection['latest_run_time'] = max(collection['latest_run_time'],
                                            summary_line['start_time'])
        collection['read_count'] += 1
        if self.include_1d_basecalling:
            collection['reads_per_channel_dist'].increment(
                summary_line['channel'])
            collection['channel_count'] = \
                collection['reads_per_channel_dist'].num_entries()
            collection['levels_sums']['open_pore_level_sum'] += \
                summary_line['median_before']
            collection['levels_sums']['count'] += 1

        # Write back
        self.summary_collection[summary_collection_key] = collection

    def _add_segmentation_data(self, summary_collection_key):
        """ Ensure the entry has a 'segmentation' section (component 0). """
        collection = self.summary_collection[summary_collection_key]
        collection.setdefault('segmentation', {'component_index': 0})

    def _add_basecall_1d_data(self, summary_line, summary_collection_key):
        """ Fold one read's 1d (template) basecalling results into the entry.

        The basecall_1d sub-collection is created on first use. Updates are made
        on a deep copy that is written back only once every field has been
        updated, so a KeyError leaves the stored sub-collection unchanged.

        :raises: KeyError if summary_line lacks a required field.

        """
        collection = self.summary_collection[summary_collection_key]
        if BASECALL_1D_COLLECTION not in collection:
            bc1d_coll = {}
            collection[BASECALL_1D_COLLECTION] = bc1d_coll
            bc1d_coll['exit_status_dist'] = defaultdict(int)
            bc1d_coll['component_index'] = 1
            bc1d_coll['seq_len_bases_sum_temp'] = 0
            bc1d_coll['read_len_events_sum_temp'] = 0
            bc1d_coll['qscore_sum_temp'] = {'sum': 0,
                                            'count': 0,
                                            'mean': 0}
            for entry in self._BASECALL_1D_HISTOGRAMS:
                bc1d_coll[entry] = Histogram(*BIN_CONFIG[entry])
            bc1d_coll['strand_median_pa'] = {'sum': 0,
                                             'count': 0,
                                             'mean': 0}
            bc1d_coll['strand_sd_pa'] = {'sum': 0,
                                         'count': 0,
                                         'mean': 0}

        # Deep copy so we can make our changes atomic
        bc1d_coll = copy.deepcopy(collection[BASECALL_1D_COLLECTION])
        _update_1d_exit_status(summary_line, bc1d_coll['exit_status_dist'])
        bc1d_coll['seq_len_bases_sum_temp'] += \
            summary_line['sequence_length_template']
        bc1d_coll['read_len_events_sum_temp'] += \
            summary_line['num_events']
        _increment_sum_count_mean(summary_line['mean_qscore_template'],
                                  bc1d_coll['qscore_sum_temp'])
        _increment_sum_count_mean(summary_line['median_current_template'],
                                  bc1d_coll['strand_median_pa'])
        _increment_sum_count_mean(summary_line['median_sd_template'],
                                  bc1d_coll['strand_sd_pa'])
        # Speeds stay at 0 when the template duration is not positive, which
        # avoids dividing by zero.
        speed_eps_template = 0
        speed_bps_template = 0
        if summary_line['template_duration'] > 0:
            speed_eps_template = (summary_line['num_events_template'] /
                                  summary_line['template_duration'])
            speed_bps_template = (summary_line['sequence_length_template'] /
                                  summary_line['template_duration'])
        # Values in the same order as _BASECALL_1D_HISTOGRAMS
        histogram_values = [summary_line['mean_qscore_template'],
                            speed_eps_template,
                            speed_bps_template,
                            summary_line['num_called_template'],
                            summary_line['sequence_length_template']]
        for entry, val in zip(self._BASECALL_1D_HISTOGRAMS, histogram_values):
            bc1d_coll[entry].increment(val)

        # Write back
        collection[BASECALL_1D_COLLECTION] = bc1d_coll

    def _add_basecall_1dsq_data(self, summary_line, summary_collection_key):
        """ Fold one read's 1dsq ('2d' columns) basecalling results into the
        entry, via a deep copy that is written back only on success.

        :raises: KeyError if summary_line lacks a required field.

        """
        collection = self.summary_collection[summary_collection_key]
        _initialize_1dsq_collection(collection)
        bc_coll = copy.deepcopy(collection[BASECALL_1DSQ_COLLECTION])

        _update_1dsq_exit_status(summary_line, bc_coll['exit_status_dist'])
        bc_coll['seq_len_bases_sum'] += summary_line['sequence_length_2d']
        _increment_sum_count_mean(summary_line['mean_qscore_2d'],
                                  bc_coll['qscore_sum'])
        bc_coll['candidate_pairs_count'] += 1
        bc_coll['qscore_dist'].increment(summary_line['mean_qscore_2d'])

        collection[BASECALL_1DSQ_COLLECTION] = bc_coll

    def _add_basecall_1dsq_candidate_data(self, read1_summary_line,
                                          read2_summary_line,
                                          summary_collection_key):
        """ Record the sequence-length ratio and open-pore gap of one pair of
        1dsq candidate reads.

        Pairs where either read has zero length are skipped.

        :raises: KeyError if a summary line lacks a required field.

        """
        collection = self.summary_collection[summary_collection_key]
        _initialize_1dsq_collection(collection)

        seq_len_read1 = read1_summary_line['sequence_length_template']
        seq_len_read2 = read2_summary_line['sequence_length_template']
        # We should have already sanity-checked our data, but better safe than
        # sorry (this also guards the division below).
        if seq_len_read1 == 0 or seq_len_read2 == 0:
            return
        seq_len_ratio = seq_len_read1 / seq_len_read2
        # Gap between the end of read1 and the start of read2.
        start_time_read2 = read2_summary_line['start_time']
        end_time_read1 = (read1_summary_line['start_time'] +
                          read1_summary_line['duration'])
        open_pore_length = start_time_read2 - end_time_read1

        # Deep copy so the two histogram updates are applied atomically
        bc_coll = copy.deepcopy(collection[BASECALL_1DSQ_COLLECTION])
        bc_coll['seq_len_ratio_bases_dist'].increment(seq_len_ratio)
        bc_coll['open_pore_length_seconds_dist'].increment(open_pore_length)

        collection[BASECALL_1DSQ_COLLECTION] = bc_coll


def generate_analysis_id():
    """ Create a fresh, random identifier for one set of telemetry.

    A TelemetryAggregator is expected to hold a single such id, which makes it
    possible to trace all telemetry back to the albacore invocation that
    produced it.

    :returns: A random (version 4) UUID in its canonical string form.
    :rtype: string

    """
    new_id = uuid.uuid4()
    return str(new_id)


def _generate_segment_summary_key(run_id, acquisition_segment):
    """ Build the summary_collection key for an acquisition-segment entry.

    Cumulative (full-run) entries are keyed by the bare run_id; segment
    entries use this ``(run_id, acquisition_segment)`` tuple instead.

    :param run_id: The run id of the reads.
    :param acquisition_segment: The (one-indexed) acquisition segment number.
    :returns: ``(run_id, acquisition_segment)``
    :rtype: tuple

    """
    return (run_id, acquisition_segment)


def _increment_sum_count_mean(value, dest_dict, count=1):
    """ Add ``value`` to a running {"sum", "count", "mean"} accumulator.

    dest_dict is modified in place: "sum" grows by value, "count" by count,
    and "mean" is recomputed as sum / count.

    :param value: The amount to add. Values that are not strictly positive
        (including NaN, for which ``> 0`` is false) are ignored entirely.
    :param dest_dict: Dict with "sum", "count" and "mean" keys.
    :param count: How many observations ``value`` represents (1 for a single
        read; larger when merging pre-aggregated sums).

    """
    if not value > 0:
        return
    dest_dict['sum'] = dest_dict['sum'] + value
    dest_dict['count'] = dest_dict['count'] + count
    dest_dict['mean'] = dest_dict['sum'] / dest_dict['count']


def _update_1d_exit_status(summary_line, exit_status_dict):
    """ Classify one read's 1d basecall and tally it in exit_status_dict.

    Reads with no template bases count as 'fail:basecall_failed'; otherwise a
    falsy 'passes_filtering' counts as 'fail:qscore_filter' and anything else
    as 'pass'. 'passes_filtering' is only read when the template has bases.

    :param summary_line: Summary dict for the read.
    :param exit_status_dict: Counter mapping status -> count; must tolerate
        incrementing a missing key (e.g. a defaultdict(int)).

    """
    if summary_line['sequence_length_template'] == 0:
        status = 'fail:basecall_failed'
    elif summary_line['passes_filtering']:
        status = 'pass'
    else:
        status = 'fail:qscore_filter'
    exit_status_dict[status] += 1


def _update_1dsq_exit_status(summary_line, exit_status_dict):
    """ Classify one read's 1dsq basecall and tally it in exit_status_dict.

    Reads whose 'sequence_length_2d' is zero count as 'fail:basecall_failed';
    otherwise a falsy 'passes_filtering' counts as 'fail:qscore_filter' and
    anything else as 'pass'.

    :param summary_line: Summary dict for the read.
    :param exit_status_dict: Counter mapping status -> count; must tolerate
        incrementing a missing key (e.g. a defaultdict(int)).

    """
    if summary_line['sequence_length_2d'] == 0:
        exit_status_dict['fail:basecall_failed'] += 1
    elif not summary_line['passes_filtering']:
        exit_status_dict['fail:qscore_filter'] += 1
    else:
        exit_status_dict['pass'] += 1


def _initialize_1dsq_collection(collection):
    """ Create the 1dsq basecalling sub-collection of a telemetry entry.

    Does nothing if the entry already has one. Otherwise it adds an empty
    exit-status counter, zeroed sums, a sum/count/mean accumulator for the
    qscore and the three 1dsq histograms configured in BIN_CONFIG.

    :param collection: A single telemetry entry (dict), modified in place.

    """
    if BASECALL_1DSQ_COLLECTION in collection:
        return
    bc_coll = {}
    collection[BASECALL_1DSQ_COLLECTION] = bc_coll
    bc_coll.update(exit_status_dist=defaultdict(int),
                   component_index=2,
                   seq_len_bases_sum=0,
                   qscore_sum={'sum': 0, 'count': 0, 'mean': 0},
                   candidate_pairs_count=0)
    for hist_name in ('qscore_dist',
                      'seq_len_ratio_bases_dist',
                      'open_pore_length_seconds_dist'):
        bc_coll[hist_name] = Histogram(*BIN_CONFIG[hist_name])


def _merge_entries(source_dict, target_dict, aggregation):
    """ Merge one telemetry entry (source_dict) into another (target_dict).

    target_dict is updated in place. segment_duration (cumulative entries
    only) and latest_run_time take the maximum of the two entries; all other
    counters, sums and histograms are combined by _combine_fields. Descriptive
    metadata (options, tracking ids, context tags, software info) is kept from
    the target rather than combined, as summing it is meaningless (and fails
    outright on None option values).

    :param source_dict: The entry being merged in.
    :param target_dict: The entry to update.
    :param aggregation: 'cumulative' for full-run entries; for any other value
        segment_duration is left untouched.
    :raises: KeyError if target_dict lacks 'segment_duration' (cumulative) or
        'latest_run_time', or source_dict lacks the corresponding field.

    """
    # Fields that are either handled explicitly here or must not be combined.
    special_fields = ['segment_duration', 'latest_run_time', 'segment_number',
                      'channel_count', 'component_index',
                      'albacore_opts', 'tracking_id', 'context_tags',
                      'software']
    if aggregation == 'cumulative':
        target_dict['segment_duration'] = max(target_dict['segment_duration'],
                                              source_dict['segment_duration'])
    target_dict['latest_run_time'] = max(target_dict['latest_run_time'],
                                         source_dict['latest_run_time'])
    _combine_fields(source_dict, target_dict, special_fields)
    # reads_per_channel_dist only exists when 1d basecalling was aggregated, so
    # 1dsq-only entries have no channel count to recompute.
    reads_per_channel = target_dict.get('reads_per_channel_dist')
    if reads_per_channel is not None:
        target_dict['channel_count'] = reads_per_channel.num_entries()


def _combine_fields(source_dict, target_dict, special_fields):
    """ Recursively add the fields of source_dict into target_dict.

    target_dict is updated in place:

    * keys in special_fields (at any nesting level) are skipped;
    * keys missing from the target are deep-copied across (e.g. an exit
      status that only occurs in the source);
    * strings, booleans and None are left as they are in the target;
    * {"sum", "count", "mean"} dicts are merged with _increment_sum_count_mean;
    * other plain dicts are combined recursively;
    * ints and floats are added;
    * Histograms are combined with ``Histogram.combine``;
    * anything else (e.g. lists) is left as it is in the target.

    :param source_dict: Dict whose values are merged in.
    :param target_dict: Dict to update.
    :param special_fields: Container of key names never to combine.

    """
    for key, value in source_dict.items():
        if key in special_fields:
            continue
        if key not in target_dict:
            target_dict[key] = copy.deepcopy(value)
            continue
        value_type = type(value)
        # Non-additive scalars: adding them is meaningless (bool) or fails
        # (str concatenation aside, None + None raises TypeError).
        if value is None or value_type is str or value_type is bool:
            continue
        if value_type is dict:
            if set(value) == {'sum', 'count', 'mean'}:
                _increment_sum_count_mean(value['sum'],
                                          target_dict[key],
                                          value['count'])
            else:
                _combine_fields(value, target_dict[key], special_fields)
        elif value_type in (int, float):
            target_dict[key] += value
        elif value_type is Histogram:
            target_dict[key].combine(value)


def _get_acquisition_segment(start_time_seconds, segment_duration_minutes):
    """ Map a read start time onto its one-indexed acquisition segment.

    :param start_time_seconds: Read start time, in seconds.
    :param segment_duration_minutes: Length of each segment, in minutes.
    :returns: 1 for reads starting within the first segment_duration_minutes,
        2 for the next segment, and so on.

    """
    return 1 + floor((start_time_seconds / 60) / segment_duration_minutes)
