"""
pipeline.py
Created on: November 14, 2016
Proprietary and confidential information of Oxford Nanopore Technologies, Limited
All rights reserved; (c)2016: Oxford Nanopore Technologies, Limited
"""
import itertools
import logging
import numpy as np
from copy import deepcopy
from configparser import ConfigParser
from albacore.pipeline_core import PipelineCore, get_debug_level as get_core_debug_level


class Pipeline(object):
    """ ONT basecalling pipeline. This uses a boost::python wrapper around the C++
    albacore project to provide a 1D and 2D basecalling pipeline for read data. The
    calling script is for resource management. This code provides a multi-threaded
    pipeline that uses a DataHandler object to process read data and save the
    results.

    Expected call sequence, per the method docstrings below: submit_read()
    while workers_ready() reports free workers; cache_completed_calls() (or
    finish_all_jobs() once no more input will be sent) to fetch finished
    results; then process_cached_calls() to write those results through the
    data handler and collect one summary dict per read.
    """

    # Per-strand result section suffixes, in a fixed order so that the
    # calibration/alignment summaries have a deterministic key order.
    _STRAND_SECTIONS = ('1d', 'template', 'complement', '2d')

    def __init__(self, worker_threads, data_handler, config_file=None,
                 analysis_obj=None, reverse_direction=False,
                 u_substitution=False, telemetry_aggregator=None,
                 logger_name=__name__):
        """ Construct a new Pipeline object.

        :param worker_threads: The number of worker threads to use. 0 means
            single-threaded mode.
        :param data_handler: A data handler object for read file IO. Its
            ``basecall_type`` attribute ('full_2d' or '1dsq') enables 2D
            output in _process_read().
        :param config_file: Configuration file to use for analysis. Its
            ``[pipeline] desc_file`` entry is passed to PipelineCore.
            Required when ``analysis_obj`` is None.
        :param analysis_obj: Provide a non-standard object for performing
            the data analysis. If None, then it will use the PipelineCore
            class.
        :param reverse_direction: reverse output fastq sequence direction for 3'->5' reads
        :param u_substitution: substitute U for T in output fastq to support RNA
        :param telemetry_aggregator: Optional object whose add_summary_data()
            is called with every per-read summary built by
            process_cached_calls().
        :param logger_name: Name of the logger used for warnings and errors.
        :raises ValueError: If neither ``config_file`` nor ``analysis_obj``
            is provided.
        """
        self.data_handler = data_handler
        self.logger_name = logger_name
        logger = logging.getLogger(self.logger_name)
        desc_file = None
        if config_file is not None:
            config = ConfigParser()
            config.read(config_file)
            desc_file = config.get('pipeline', 'desc_file')

        # Initialize the basecall pipeline.
        if analysis_obj is None:
            if config_file is None:
                # Without a config file there is no desc_file to give the core
                # (this used to surface as an UnboundLocalError).
                raise ValueError('config_file is required when analysis_obj is not provided.')
            self.core = PipelineCore(desc_file, config_file, worker_threads)
            for core_warning in self.core.get_warnings():
                logger.warning(core_warning)
        else:
            self.core = analysis_obj
        # data_id -> read metadata object returned by the data handler; filled
        # by submit_read() and emptied as reads are finished.
        self.data_id_mapping = {}
        # Results most recently fetched from the core, keyed by read id (plus
        # an optional 'error' entry describing a failed read).
        self.cached_data = {}
        self.basecall_type = data_handler.basecall_type
        self.reverse_direction = reverse_direction
        self.u_substitution = u_substitution
        self.telemetry_aggregator = telemetry_aggregator

    def workers_ready(self):
        """ Returns the number of free workers (if any).

        The cache_completed_calls() method should be called before
        calling this. Otherwise, workers which have finished since
        the last call will not be reported as complete.
        """
        return self.core.workers_ready()

    def cache_completed_calls(self):
        """ Caches basecall data for any completed reads.

        :returns: The number of completed reads.

        Any previously cached data is replaced, so process_cached_calls()
        must be called before calling this again. This should be called
        before calling the workers_ready() method.
        """
        self.cached_data = self.core.get_results()
        return len(self.cached_data)

    def finish_all_jobs(self):
        """ Finish all currently processing reads, and cache the results.

        :returns: The number of completed reads.

        This should only be called once all worker threads are ready, and
        no more data will be sent to the pipeline.
        """
        self.cached_data = self.core.get_results()
        self.core.finish_all_jobs()
        self.cached_data.update(self.core.get_results())
        return len(self.cached_data)

    def submit_read(self, name):
        """ Submit a read file for processing.

        :param name: The identifier of the read to submit, as required by the
            registered DataHandler object.
        :returns: True if the data was successfully submitted. False if the
                  data handler could not provide the requested data (the
                  reason is logged as a warning).
        :raises Exception: If no worker is free.
        """
        if self.workers_ready() == 0:
            raise Exception('Attempted to submit a read file, but no workers are available.')
        try:
            data, meta = self.data_handler.load_read_data(name)
            data_id = meta.data_id
        except Exception as err:
            # Best-effort: an unloadable read is skipped, but leave a trace of why.
            logging.getLogger(self.logger_name).warning(
                'Could not load read data for {}: {}'.format(name, err))
            return False
        self.data_id_mapping[data_id] = meta
        parms = meta.get_read_data_dict_for_pipeline()
        parms.update(data)
        self.core.pass_data(parms)
        return True

    def process_cached_calls(self):
        """ Process any cached basecall data.

        :returns: Summary data for output.

        This should always be called after the cache_completed_calls() method
        has been called, if that method returned a non-zero count. Otherwise
        cached basecall data will be discarded the next time
        cache_completed_calls() is called.

        The returned summary data is a list of dicts, with one entry per processed
        read, in sorted read-id order. A failed read that was not finished by a
        regular entry gets an extra general-data-only entry at the end.
        """
        logger = logging.getLogger(self.logger_name)
        summary_data = []
        reads = sorted(read for read in self.cached_data if read != 'error')
        data_id_to_remove = None
        if 'error' in self.cached_data:
            error_data = self.cached_data['error']
            data_id_to_remove = error_data['error']['data_id']
            # Writes the failing read's log; its summary row is built below.
            self._process_read('error', error_data)
        for read_id in reads:
            read_meta, summary_line = self._process_read(read_id, self.cached_data[read_id])
            summary_data.append(summary_line)
            if self.telemetry_aggregator is not None:
                try:
                    self.telemetry_aggregator.add_summary_data(
                        summary_line,
                        read_meta.tracking_id,
                        read_meta.context_tags)
                except Exception as e:
                    logger.error('Error inserting read {} into '
                                 'telemetry'.format(read_id))
                    logger.error(e)
        # If no regular entry finished the failing read (which would have
        # removed it from data_id_mapping), emit a minimal summary and close it.
        if data_id_to_remove in self.data_id_mapping:
            general_data = self._extract_general_data(data_id_to_remove, self.cached_data['error'])
            summary_data.append(dict(general_data))
            self.data_handler.finish_read(data_id_to_remove)
            del self.data_id_mapping[data_id_to_remove]
        self.cached_data = {}
        return summary_data

    ##############################
    #
    # Private methods below.
    #
    ##############################

    def _collapse_subdicts(self, data, format_str='{sub_key}_{main_key}'):
        '''
        Collapses a dict-in-dict structure into one dict
        :returns dict whose keys are made from format string, where
                 main_key is a key in data, and sub_key is a key in
                 data[main_key].
        '''
        return {format_str.format(main_key=main_key, sub_key=sub_key): content
                for main_key, sub_dict in data.items()
                for sub_key, content in sub_dict.items()}

    def _prepare_alignment_data_for_summary(self, data, section_name):
        """ Strip bookkeeping fields from each strand section and flatten it.

        Removes ``read_id``, ``channel`` and ``<section_name>_sam_output``
        from every sub-dict of ``data`` (in place), then collapses the result
        with _collapse_subdicts() so keys become ``<field>_<strand>``.
        """
        for section_data in data.values():
            for key in ('read_id', 'channel', section_name + '_sam_output'):
                section_data.pop(key, None)
        return self._collapse_subdicts(data)

    def _process_read(self, read_id, read_data):
        """ Write out all results for one read and build its summary row.

        :param read_id: Key of the read in the cached results ('error' for
            the entry describing a failed read).
        :param read_data: Dict of result sections for the read. The first
            section must contain the read's ``data_id``. ``read_id`` is
            added to it.
        :returns: ``(read_info, summary)``: the metadata object registered by
            submit_read() and a flat summary dict. For the 'error' entry the
            summary holds the general fields plus ``error_message`` and
            ``log``, and the read is NOT finished or removed from
            data_id_mapping here.
        :raises Exception: If the first section has no ``data_id``.
        """
        sections = list(read_data)
        if not sections or 'data_id' not in read_data[sections[0]]:
            raise Exception("data_id not found. Keys are: " + str(sections))
        data_id = read_data[sections[0]]['data_id']
        read_info = self.data_id_mapping[data_id]
        read_data['read_id'] = read_id

        general_data = self._extract_general_data(data_id, read_data)

        log_data = self._extract_log_data(data_id, read_data)
        self.data_handler.output_log(data_id, log_data)

        if read_id == 'error':
            summary_out = dict(general_data)
            summary_out['error_message'] = read_data['error'].get('error_message', 'unknown error')
            summary_out['log'] = log_data['log'] if log_data else None
            return read_info, summary_out

        segmentation_data = self._extract_segmentation_data(data_id, read_data)
        basecall_1d_data = self._extract_basecall_1d_data(data_id, read_data)
        basecall_2d_data = self._extract_basecall_2d_data(data_id, read_data)
        calib_data = self._extract_calib_data(data_id, read_data)
        barcoding_data = self._extract_barcoding_data(data_id, read_data)
        alignment_data = self._extract_alignment_data(data_id, read_data)

        segmentation_tag = self.data_handler.output_segmentation(data_id, segmentation_data)
        basecall_1d_tag = self.data_handler.output_basecall_1d(data_id, basecall_1d_data, segmentation_tag)
        basecall_2d_tag = None
        if self.basecall_type in ['full_2d', '1dsq']:
            basecall_2d_tag = self.data_handler.output_basecall_2d(data_id, basecall_2d_data, segmentation_tag, basecall_1d_tag)
        if calib_data:
            self.data_handler.output_calib_detection(data_id, calib_data, segmentation_tag, basecall_1d_tag, basecall_2d_tag)
            calib_data = self._prepare_alignment_data_for_summary(calib_data, 'calibration_strand')
        if barcoding_data:
            self.data_handler.output_barcoding(data_id, barcoding_data, segmentation_tag, basecall_1d_tag, basecall_2d_tag)
        if alignment_data:
            self.data_handler.output_alignment(data_id, alignment_data, segmentation_tag, basecall_1d_tag, basecall_2d_tag)
            alignment_data = self._prepare_alignment_data_for_summary(alignment_data, 'alignment')

        filtering_dict = {
            'passes_filtering': self.data_handler.output_filtering(data_id)}

        self.data_handler.finish_read(data_id)

        # Bulk per-event arrays and FASTQ records were written above; keep
        # them out of the summary row.
        for key in ('template_data', 'complement_data', 'template_fastq', 'complement_fastq'):
            basecall_1d_data.pop(key, None)
        for key in ('2d_alignment_data', '2d_labl_data', '2d_fastq', '2d_movement'):
            basecall_2d_data.pop(key, None)

        summary_out = dict(general_data)
        for part in (segmentation_data, basecall_1d_data, basecall_2d_data,
                     calib_data, barcoding_data, alignment_data, filtering_dict):
            summary_out.update(part)
        del self.data_id_mapping[data_id]
        return read_info, summary_out

    def _extract_log_data(self, data_id, read_data):
        """ Return the read's 'log' section, or an empty dict if absent. """
        return read_data.get('log', {})

    @staticmethod
    def _template_event_count(read_data):
        """ Template event count: the event detector's ``num_events`` if
        present, else the 1D callback's ``called_events``, else 0.
        """
        if 'event_detector_summary' in read_data:
            return read_data['event_detector_summary'].get('num_events', 0)
        if 'basecall_1d_callback' in read_data:
            return read_data['basecall_1d_callback'].get('called_events', 0)
        return 0

    @staticmethod
    def _kmer_labels(label_len):
        """ All ACGT k-mers of length ``label_len`` in itertools.product order.

        Model-state indices from the core are looked up in this list.
        """
        return [''.join(kmer) for kmer in itertools.product('ACGT', repeat=label_len)]

    def _extract_general_data(self, data_id, read_data):
        """ Build the per-read fields common to every summary row.

        Starts from the metadata's summary dict, renames ``channel_id`` to
        ``channel``, converts ``start_time`` and the trimmed ``duration`` from
        samples to seconds, and adds ``num_events``.
        """
        read_info = self.data_id_mapping[data_id]
        num_events = self._template_event_count(read_data)
        duration = 0
        if 'data_trimmer_summary' in read_data:
            duration = read_data['data_trimmer_summary'].get('duration', 0)
        data = read_info.get_read_data_dict_for_summary()
        data['run_id'] = data.pop('run_id', 'unknown')
        data['channel'] = data.pop('channel_id', 'unknown')
        sampling_rate = data.get('sampling_rate', 1.0)
        data['start_time'] = round(data.get('start_time', 0.0) / sampling_rate, 5)
        data['duration'] = round(duration / sampling_rate, 5)
        data['sampling_rate'] = round(sampling_rate, 2)
        data['num_events'] = num_events
        return data

    def _extract_segmentation_data(self, data_id, read_data):
        """ Build template/complement segmentation fields for the read.

        :returns: An empty dict when there is no 'data_trimmer_summary';
            otherwise ``read_id``, ``has_template``/``has_complement`` and,
            per strand present, event count, first sample, duration in
            samples and start/duration in seconds.
        """
        if 'data_trimmer_summary' not in read_data:
            return {}

        trimmer_summary = read_data['data_trimmer_summary']
        sampling_rate = self.data_id_mapping[data_id].sampling_rate

        num_events_template = self._template_event_count(read_data)
        # Missing fields count as 0, matching _extract_general_data().
        duration_template = trimmer_summary.get('duration', 0) - trimmer_summary.get('stall_duration', 0)
        # NOTE(review): complement event count and duration are never filled
        # from the results and are always reported as 0.
        num_events_complement = 0
        duration_complement = 0

        template_data = read_data.get('basecall_1d_callback', read_data.get('basecall_template_callback', None))
        complement_data = read_data.get('basecall_complement_callback', None)
        has_template = template_data is not None
        has_complement = complement_data is not None
        data = {'read_id': read_data['read_id'],
                'has_template': has_template,
                'has_complement': has_complement}
        if has_template:
            template_start_time = template_data.get('strand_start_time', 0)
            data.update({'num_events_template': num_events_template,
                         'first_sample_template': template_start_time,
                         'duration_template': duration_template,
                         'template_start': round(template_start_time / sampling_rate, 5),
                         'template_duration': round(duration_template / sampling_rate, 5)})
        if has_complement:
            complement_start_time = complement_data.get('strand_start_time', 0)
            data.update({'num_events_complement': num_events_complement,
                         'first_sample_complement': complement_start_time,
                         'duration_complement': duration_complement,
                         'complement_start': round(complement_start_time / sampling_rate, 5),
                         'complement_duration': round(duration_complement / sampling_rate, 5)})
        return data

    def _extract_basecall_1d_data(self, data_id, read_data):
        """ Build per-strand 1D basecall results for output and summary.

        :returns: A dict with ``read_id`` plus, for each 'template' and
            'complement' section present, per-strand statistics, an event
            structured array under ``<section>_data`` and a FASTQ record
            under ``<section>_fastq``. Only ``read_id`` is returned when
            there is no 1D basecall data.
        """
        if 'event_detector_summary' not in read_data and 'basecall_1d_callback' not in read_data:
            return {'read_id': read_data['read_id']}
        # The results carry no complement event count, so it is reported as 0
        # (as in _extract_segmentation_data) instead of raising KeyError.
        num_events = {'template': self._template_event_count(read_data), 'complement': 0}

        template_data = read_data.get('basecall_1d_callback', read_data.get('basecall_template_callback', None))
        complement_data = read_data.get('basecall_complement_callback', None)
        sections = {}
        if template_data is not None:
            sections['template'] = template_data
        if complement_data is not None:
            sections['complement'] = complement_data
        if not sections:
            return {'read_id': read_data['read_id']}

        # All sections are expected to share one label length; the last
        # section's value is used.
        labllen = [section_data['label_len'] for section_data in sections.values()][-1]
        labl_dtype = '<S{}'.format(labllen)
        event_field_names = ['mean', 'stdv', 'start', 'length',
                             'model_state', 'move', 'weights', 'p_model_state', 'mp_state', 'p_mp_state',
                             'p_A', 'p_C', 'p_G', 'p_T']
        event_field_dtypes = ['<f4', '<f4', '<u8', '<u8'] + [labl_dtype, '<i4'] + ['<f4'] * 2 + [labl_dtype] + ['<f4'] * 5
        event_dtype = list(zip(event_field_names, event_field_dtypes))
        fastq_label = self.data_id_mapping[data_id].label
        # FASTQ names use the template callback's read_id; fall back to the
        # complement's when there is no template section.
        fastq_read_id = (template_data if template_data is not None else complement_data)['read_id']
        all_labels = self._kmer_labels(labllen)
        data = {'read_id': read_data['read_id']}
        for section, section_data in sections.items():
            num_called = section_data['mean'].size
            data['num_events_{}'.format(section)] = num_events[section]
            data['num_called_{}'.format(section)] = num_called
            dataset = np.empty(num_called, dtype=event_dtype)
            for field, the_dtype in event_dtype:
                if the_dtype.startswith('<S'):
                    # hdf5 doesn't like unicode, so state indices are mapped to
                    # k-mer strings stored in a fixed-width bytes field.
                    dataset[field][:] = [all_labels[state] for state in section_data[field]]
                else:
                    dataset[field][:] = section_data[field]
            data['strand_score_{}'.format(section)] = round(section_data['call_score'], 4)
            data['stay_prob_{}'.format(section)] = round(section_data['stay_prob'], 4)
            data['step_prob_{}'.format(section)] = round(section_data['step_prob'], 4)
            data['skip_prob_{}'.format(section)] = round(section_data['skip_prob'], 4)
            data['{}_data'.format(section)] = dataset
            data['{}_fastq'.format(section)] = _make_fastq(fastq_read_id, 'Basecall_1D_{}'.format(section),
                                                           fastq_label, section_data['sequence'], section_data['qstring'],
                                                           reverse_sequence=self.reverse_direction,
                                                           u_substitution=self.u_substitution)
            data['mean_qscore_{}'.format(section)] = round(section_data['mean_qscore'], 3)
            data['sequence_length_{}'.format(section)] = section_data['sequence_length']
            data['median_current_{}'.format(section)] = section_data['median_current']
            data['median_sd_{}'.format(section)] = section_data['median_sd']
        return data

    def _extract_basecall_2d_data(self, data_id, read_data):
        """ Build 2D basecall results for output and summary.

        :returns: Only ``read_id`` when there is no 2D call or its sequence
            is empty; otherwise also the mean q-score, sequence length,
            alignment, k-mer labels, movement and a FASTQ record. The 2D
            FASTQ is never reversed.
        """
        call_2d_data = read_data.get('basecall_2d_callback', None)
        if call_2d_data is None:
            return {'read_id': read_data['read_id']}
        sequence = call_2d_data.get('sequence')
        # Checked first so an empty call does not need label_len/model_state.
        if sequence is None or len(sequence) == 0:
            return {'read_id': read_data['read_id']}
        all_labels = self._kmer_labels(call_2d_data['label_len'])
        labls = [all_labels[state] for state in call_2d_data.get('model_state')]
        fastq_label = self.data_id_mapping[data_id].label
        data = {'read_id': read_data['read_id'],
                'mean_qscore_2d': round(call_2d_data.get('mean_qscore', 0.0), 3),
                'sequence_length_2d': len(sequence),
                '2d_alignment_data': call_2d_data.get('alignment'),
                '2d_labl_data': labls,
                '2d_movement': call_2d_data.get('movement'),
                '2d_fastq': _make_fastq(call_2d_data['read_id'], 'Basecall_2D', fastq_label, sequence,
                                        call_2d_data.get('qstring'),
                                        reverse_sequence=False, u_substitution=self.u_substitution)}
        return data

    def _collect_strand_sections(self, read_data, key_prefix, description):
        """ Gather the non-empty per-strand sections ``<key_prefix><strand>``.

        A '1d' section is renamed to 'template'. Sections come back in a
        fixed order (template, complement, 2d) as a deep copy, so callers
        may modify them.

        :raises Exception: If both a '1d' and a 'template' section are present.
        """
        found = {}
        for section in self._STRAND_SECTIONS:
            content = read_data.get(key_prefix + section, None)
            if content:
                found[section] = content
        if '1d' in found:
            if 'template' in found:
                raise Exception("If 1D {} is present, "
                                "no other strand section is allowed.".format(description))
            renamed = {'template': found.pop('1d')}
            renamed.update(found)
            found = renamed
        return deepcopy(found)

    def _extract_calib_data(self, data_id, read_data):
        """ Per-strand calibration strand detector summaries (may be empty). """
        return self._collect_strand_sections(read_data, 'calib_detector_summary_',
                                             'calibration strand detector data')

    def _extract_barcoding_data(self, data_id, read_data):
        """ Copy of the barcode summary with ``barcode_`` stripped from keys.

        ``barcode_arrangement``, ``barcode_full_arrangement`` and
        ``barcode_score`` keep their prefix. Returns an empty dict when the
        read has no barcode summary.
        """
        barcode_data = read_data.get('barcode_summary', None)
        if barcode_data is None:
            return {}
        data = deepcopy(barcode_data)
        data.pop('channel', None) # 'Channel' is worker_id this can incorrectly overwrite the 'channel_id' - TUNA-162

        # except for three special cases, we want to strip "barcode_" from the keys
        for key in list(data):
            if key.startswith('barcode_'):
                short_key = key[len('barcode_'):]
                if short_key not in ('arrangement', 'full_arrangement', 'score'):
                    data[short_key] = data.pop(key)
        return data

    def _extract_alignment_data(self, data_id, read_data):
        """ Per-strand aligner summaries (may be empty). """
        return self._collect_strand_sections(read_data, 'aligner_summary_', 'alignment data')


def _make_fastq(read_id, basename, label, sequence, qstring, reverse_sequence=False, u_substitution=False):
    """ Build a FASTQ record dict for one basecall.

    :param read_id: Read identifier; forms the first part of the record name.
    :param basename: Analysis name appended to the read id (e.g. 'Basecall_2D').
    :param label: Free-text label appended after a space in the record name.
    :param sequence: Called bases, as a string or an iterable of characters.
    :param qstring: Quality characters, as a string or an iterable of characters.
    :param reverse_sequence: If True, reverse both bases and qualities.
    :param u_substitution: If True, replace every uppercase 'T' with 'U'.
    :returns: ``{'name': ..., 'sequence': ..., 'qstring': ...}``.
    """
    step = -1 if reverse_sequence else 1
    bases = ''.join(sequence)[::step]
    quals = ''.join(qstring)[::step]
    if u_substitution:
        bases = bases.replace("T", "U")
    header = '{}_{} {}'.format(read_id, basename, label)
    return {'name': header, 'sequence': bases, 'qstring': quals}
