# Copyright (c) 2013, Ian Reid, Concordia University Centre for Structural and Functional Genomics
# All rights reserved.

"""For a gene model with RNA-Seq read coverage data, derive a set of splicing isoforms that can reproduce the pattern
of intron readthrough shown by the read coverage.

Created: 2012-04-18
Author: ian
Modified to use Stranded_Read_Initiation_Probability_Vector 2012-12-07
"""
import sys
import os
import math

this_dir = os.path.dirname(__file__)
src = os.path.dirname(this_dir)
sys.path.append(src)
from lib.gff3Record import GFF3Gene
from lib.gff3Iterator import GFF3Iterator
from lib.find_read_islands import get_cover_blocks, get_mean_depth
from calcTranscriptFragmentationProbabilities import get_exon_origin_counts_list, \
    calc_probs_from_counts, Stranded_Read_Initiation_Probability_Vector
from pysam import Tabixfile
from collections import namedtuple
import shelve
import argparse

# Text shown by --help; mirrors the module docstring.
DESCRIPTION = ('For a gene model with RNA-Seq read coverage data, derive a set of splicing isoforms that can '
               'reproduce the pattern of intron readthrough shown by the read coverage')
VERSION = '0.1'

# One exon or intron segment of a gene model: its location (scaffold, start, end), a depth value
# supplied by the caller, and its read-initiation probability vector.
SegmentProbabilityProfile = namedtuple('SegmentProbabilityProfile', 'scaffold start end depth prob_vector')


def usage(msg=None):
    """Write an optional message plus a usage line to stderr, then exit with status 1.

    NOTE(review): nothing in this file calls usage(); the command line is parsed by get_args(),
    whose options differ from the positional arguments listed in this usage line.

    :param msg: optional text printed before the usage line when truthy
    """
    err = sys.stderr
    if msg:
        print >> err, msg
    print >> err, ('Usage: python %s  genes.gff3  coverage.wig.gz   counted.origins.txt.gz output.gff3 '
                   'probabilities.shelf') % sys.argv[0]
    sys.exit(1)


def get_introns(exons):
    """Return the intron features lying between consecutive exons.

    Each intron is a clone of its upstream exon whose span is moved to the gap between that exon and the
    next one, and whose type is set to 'intron'. Fewer than two exons yields an empty list.

    :param exons: exon features, assumed already sorted in genomic order
    :return: list of intron features, one per adjacent exon pair
    """
    def gap_between(upstream, downstream):
        # Clone so the intron keeps the exon's seqID/strand/attributes without mutating the exon.
        gap = upstream.clone()
        gap.set_start(upstream.get_end() + 1)
        gap.set_end(downstream.get_start() - 1)
        gap.set_type('intron')
        return gap

    return [gap_between(left, right) for left, right in zip(exons, exons[1:])]


def get_segment_mean_depths(segments, coverage):
    """Return the mean read depth of each segment, in the same order as ``segments``.

    :param segments: exon/intron features to measure
    :param coverage: tabix-indexed coverage file handed to get_cover_blocks
    :return: list with get_mean_depth(get_cover_blocks(segment, coverage)) for every segment
    """
    return [get_mean_depth(get_cover_blocks(seg, coverage)) for seg in segments]


def get_segment_probability_profile(segment, depth, origin_count_iter):
    """Build the read-initiation probability profile of one exon/intron segment.

    A flat prior of 1/len(segment) per position on each strand is passed, together with the segment's
    read-origin counts, to calc_probs_from_counts, whose result becomes the profile's prob_vector.
    Assumes len(segment) is the segment length in bases -- TODO confirm against GFF3 record.

    :param segment: exon or intron feature
    :param depth: depth value stored unchanged in the returned profile
    :param origin_count_iter: tabix-indexed read-origin counts
    :return: SegmentProbabilityProfile for the segment
    """
    origin_counts = get_exon_origin_counts_list(segment, origin_count_iter, 0)
    n_positions = len(segment)
    uniform = 1.0 / n_positions
    # Two distinct lists, so the plus and minus strands never share storage.
    prior = Stranded_Read_Initiation_Probability_Vector(_plus=[uniform] * n_positions,
                                                        _minus=[uniform] * n_positions)
    profile_probs = calc_probs_from_counts(origin_counts, prior)
    return SegmentProbabilityProfile(segment.get_seqID(), segment.get_start(), segment.get_end(), depth,
                                     profile_probs)


def concatenate_probability_profiles(profile_list, slice_height, weights):
    """Join segment probability vectors end-to-end into one isoform-level vector.

    Each segment's vector is scaled by its share of the weighted total length,
    span * weight / sum(span * weight), before being appended.

    :param profile_list: SegmentProbabilityProfile objects, in isoform order
    :param slice_height: not used by this function (kept for caller compatibility)
    :param weights: per-profile weights, parallel to profile_list
    :return: concatenated Stranded_Read_Initiation_Probability_Vector
    """
    # Inclusive coordinates: a segment covering start..end has 1 + end - start positions.
    spans = [1 + profile.end - profile.start for profile in profile_list]
    total_weight = sum([spans[k] * weights[k] for k in range(len(spans))])
    combined = Stranded_Read_Initiation_Probability_Vector.empty()
    for span, profile, weight in zip(spans, profile_list, weights):
        piece = profile.prob_vector.clone()  # scale a copy; the segment profile may be reused
        piece.scale(float(span) * weight / total_weight)
        combined.extend(piece)
    return combined


def create_isoform(isoform_count, transcript1, slice_height, used_segs):
    """Create one artificial isoform transcript from the segments it retains.

    The isoform is a clone of ``transcript1`` with a ``depth`` attribute, the ID
    ``<transcript ID>_<isoform_count>``, a span covering all used segments, and one exon per run of
    abutting segments.

    :param isoform_count: 1-based index of this isoform, used in its ID
    :param transcript1: template transcript (the caller passes an exon-less clone)
    :param slice_height: depth recorded in the isoform's ``depth`` attribute
    :param used_segs: non-empty list of exon/intron segments in genomic order
    :return: the new isoform transcript
    """
    new_isoform = transcript1.clone()
    new_isoform.add_attribute('depth', slice_height)
    new_isoform.set_ID("%s_%d" % (transcript1.get_ID(), isoform_count))
    new_isoform.set_start(min(s.get_start() for s in used_segs))
    new_isoform.set_end(max(s.get_end() for s in used_segs))
    for block in merge_abutting(used_segs):
        new_exon = block.clone()
        new_exon.clear_parents()
        # A merged block is cloned from its first segment. That segment is an intron whenever a retained
        # intron is deeper than the exon preceding it, and the block would then be emitted as an
        # 'intron' feature although it is added as an exon.
        if new_exon.get_type() == 'intron':
            new_exon.set_type('exon')
        new_isoform.add_exon(new_exon)

    return new_isoform


def merge_abutting(seg_list):
    """Merge runs of abutting or overlapping segments into single features.

    Segments are assumed sorted by start. Each merged feature is a clone of the first segment in its run,
    with its end extended to the furthest end in the run. The input segments are not modified.

    :param seg_list: segments sorted by start coordinate
    :return: list of merged features; empty when seg_list is empty
    """
    if not seg_list:
        return []
    merged = []
    current = seg_list[0].clone()
    for seg in seg_list[1:]:
        if seg.get_start() <= current.get_end() + 1:
            # Only extend: a segment contained in the current block must not pull its end back.
            if seg.get_end() > current.get_end():
                current.set_end(seg.get_end())
        else:
            merged.append(current)
            current = seg.clone()
    merged.append(current)
    return merged


def add_1_isoform(isoform_count, segments, seg_depths, seg_profiles, total_height, slice_height, transcript1, gene,
                  all_weights, saved_prob_vectors=None):
    """Add one artificial isoform made of every segment deeper than ``total_height``.

    :param isoform_count: number of isoforms already created for this gene
    :param segments: alternating exon/intron features of the source transcript, in genomic order
    :param seg_depths: mean depth of each segment (parallel to segments)
    :param seg_profiles: SegmentProbabilityProfile of each segment (parallel to segments)
    :param total_height: depth already accounted for by earlier isoforms
    :param slice_height: depth assigned to the new isoform
    :param transcript1: exon-less template transcript
    :param gene: gene record that receives the new transcript
    :param all_weights: per-segment weights (parallel to segments) used when concatenating profiles
    :param saved_prob_vectors: mapping (e.g. a shelf) that receives the isoform's probability vector keyed
        by isoform ID. When omitted, the module-level ``saved_prob_vectors`` opened by the script entry
        point is used, as in earlier versions.
    :return: the updated isoform count
    :raises NameError: if no store was passed and none exists at module level
    """
    if saved_prob_vectors is None:
        # Backward compatibility: earlier versions always wrote to the module-level shelf.
        saved_prob_vectors = globals().get('saved_prob_vectors')
        if saved_prob_vectors is None:
            raise NameError("saved_prob_vectors: no probability-vector store was passed and none is open "
                            "at module level")
    isoform_count += 1
    used = [i for i in range(len(segments)) if seg_depths[i] > total_height]
    used_segs = [segments[i] for i in used]
    used_profiles = [seg_profiles[i] for i in used]
    weights = [all_weights[i] for i in used]
    new_isoform = create_isoform(isoform_count, transcript1, slice_height, used_segs)
    prob_vector = concatenate_probability_profiles(used_profiles, slice_height, weights)
    saved_prob_vectors[new_isoform.get_ID()] = prob_vector
    gene.add_transcript(new_isoform)
    return isoform_count


def _interleave_segments(exons, introns):
    """Return the gene model as alternating segments: exon0, intron0, exon1, ..., last exon.

    :param exons: exons sorted in genomic order
    :param introns: the gaps between consecutive exons (one fewer than exons)
    """
    segments = []
    for exon, intron in zip(exons, introns):
        segments.extend((exon, intron))
    segments.append(exons[-1])
    return segments


def add_isoforms(gene, coverage, origin_count_iter, saved_prob_vectors):
    '''Replace a gene's transcripts with artificial isoforms that stack up to its coverage profile.

    For a multi-exon gene, the first transcript is split into alternating exon/intron segments. Its coverage
    is then peeled off in horizontal slices: each slice becomes one isoform made of every segment deeper
    than the depth already accounted for, with abutting retained segments merged into single exons.
    Slicing continues until the deepest intron is covered, and a final slice absorbs any remaining exon
    depth. A monoexonic gene keeps its transcript, tagged with a ``depth`` attribute of 1.
    In both cases each emitted transcript's read-initiation probability vector is stored in
    ``saved_prob_vectors`` under the transcript ID.

    :param gene: A gene record populated with a transcript and exon(s)
    :type gene: GFF3Gene
    :param coverage: read coverage depths, tabix-indexed
    :type coverage: Tabixfile
    :param origin_count_iter: Origin counts list, tabix-indexed
    :type origin_count_iter: Tabixfile
    :param saved_prob_vectors: output shelf file for read origin probability vectors
    :type saved_prob_vectors: shelve
    :return: Gene record with artificial isoforms as transcripts
    :rtype: GFF3Gene
    '''
    transcript0 = gene.get_transcripts()[0]
    exons0 = sorted(transcript0.get_exons())
    introns = get_introns(exons0)

    if not introns:  # monoexonic gene
        slice_height = 1
        transcript0.add_attribute('depth', 1)
        probs = get_segment_probability_profile(exons0[0], 1, origin_count_iter)
        prob_vector = concatenate_probability_profiles([probs], slice_height, [1])
        saved_prob_vectors[transcript0.get_ID()] = prob_vector
        return gene

    segments = _interleave_segments(exons0, introns)
    gene.set_transcripts([])
    seg_depths = get_segment_mean_depths(segments, coverage)
    max_seg_depth = max(seg_depths)
    max_intron_depth = max(d for seg, d in zip(segments, seg_depths) if seg.get_type() == 'intron')
    seg_profiles = [get_segment_probability_profile(segment, depth, origin_count_iter)
                    for segment, depth in zip(segments, seg_depths)]

    # start with monoexonic transcript and add introns one-by-one
    transcript1 = transcript0.clone()
    transcript1.clear_exons()
    isoform_count = 0
    total_height = 0
    while total_height < max_intron_depth:
        # Next slice reaches up to the shallowest segment still above the current height (rounded up).
        slice_height = int(math.ceil(min(d for d in seg_depths if d > total_height))) - total_height
        if slice_height > 0:
            weights = [min(1.0, (float(d) - total_height) / slice_height) for d in seg_depths]
            isoform_count = add_1_isoform(isoform_count, segments, seg_depths, seg_profiles, total_height,
                                          slice_height, transcript1, gene, weights, saved_prob_vectors)
        total_height += slice_height
    if total_height < max_seg_depth:  # top layer for residual exon depth
        slice_height = int(math.ceil(max_seg_depth - total_height))
        if slice_height > 0:
            weights = [(float(d) - total_height) / slice_height for d in seg_depths]
            isoform_count = add_1_isoform(isoform_count, segments, seg_depths, seg_profiles, total_height,
                                          slice_height, transcript1, gene, weights, saved_prob_vectors)
    return gene


def get_args(argv=None):
    """Parse the command line.

    :param argv: argument strings to parse (program name excluded); None, the default, parses sys.argv[1:]
    :return: namespace with ``input``, ``out``, ``coverage``, ``counts``, ``shelf`` and ``verbose``
        (``verbose`` is parsed but not consulted elsewhere in this file)
    :rtype: argparse.Namespace
    """
    argparser = argparse.ArgumentParser(description=DESCRIPTION)
    # standard options
    # The space after %(prog)s keeps --version from printing e.g. "script.py0.1".
    argparser.add_argument('--version', action='version', version='%(prog)s ' + VERSION)
    argparser.add_argument('--verbose', '-v', action='count', default=0,
                           help='Omit to see only fatal error messages; -v to see warnings; -vv to see warnings and '
                                'progress messages')
    # options to customize
    # With nargs='?', a bare "--in"/"--out" yields ``const``. Setting const to the standard streams avoids
    # handing None to the GFF3 reader/writer.
    argparser.add_argument('--in', '-i', dest='input', type=argparse.FileType('r'), nargs='?', default=sys.stdin,
                           const=sys.stdin,
                           help='Path to the genes.gff3 input file; if omitted or -, input is read from stdin')
    argparser.add_argument('--out', '-o', type=argparse.FileType('w'), nargs='?', default=sys.stdout,
                           const=sys.stdout,
                           help='Path to the isoforms.gff3 output file; if omitted or -, output is written to stdout')
    argparser.add_argument('--coverage', required=True, help='Path to a tabix-indexed coverage.wig.gz file; required')
    argparser.add_argument('--counts', required=True, help='Path to a tabix-indexed origin_counts.gz file; required')
    argparser.add_argument('--shelf', '-s', required=True,
                           help='Path to a shelve file for saving read initiation probabilities; required')
    return argparser.parse_args(argv)


if __name__ == '__main__':
    args = get_args()
    genes = GFF3Iterator(args.input).genes()
    coverage = Tabixfile(args.coverage)
    origin_count_iter = Tabixfile(args.counts)
    output = args.out
    # Module-level name kept: older add_1_isoform versions look this shelf up as a global.
    saved_prob_vectors = shelve.open(args.shelf, 'c')

    try:
        for gene in genes:
            try:
                gene = add_isoforms(gene, coverage, origin_count_iter, saved_prob_vectors)
                print >> output, gene
            except Exception as e:
                # Best effort: report the failing gene and continue with the rest.
                print >> sys.stderr, e, gene.get_ID()
    finally:
        # Close even if iteration aborts (e.g. a GFF3 parse error), so results written so far are kept.
        output.close()
        saved_prob_vectors.close()
    print >> sys.stderr, sys.argv[0], 'done.'
