# Copyright (c) 2013, Ian Reid, Concordia University Centre for Structural and Functional Genomics
# All rights reserved.

'''
From a counted.origins file and a .gff3 file giving transcript positions, calculate the empirical probability of
fragmentation at each position in each transcript
Created on Apr 15, 2010
@author: ian
Modified to produce Stranded_Read_Initiation_Probability_Vectors 2012-12-07
'''

import sys
import os
# Make the package root (the parent of this script's directory) importable so
# that "lib.gff3Iterator" below resolves regardless of the current directory.
this_dir = os.path.dirname(__file__)
src = os.path.dirname(this_dir)
sys.path.append(src)
from lib.gff3Iterator import GFF3Iterator
import shelve
from pysam import Tabixfile


def usage():
    '''Print the command-line usage message to stdout and exit with status 1.'''
    # print as a call with a single argument: identical output under Python 2
    # and also valid under Python 3 (the original print statement was py2-only)
    print('Usage: python %s transcripts.gff3  counted.origins.txt.gz  output.shelf' % sys.argv[0])
    sys.exit(1)


class origin_count_line:
    '''One record of a counted.origins file: the number of read starts at a
    single reference position, split by direction (left and right counts).'''

    def __init__(self, seq_id, position, left_count, right_count):
        self.seq_id = seq_id
        self.position = int(position)
        self.left_count = float(left_count)
        self.right_count = float(right_count)

    def _as_tuple(self):
        # Field order used by both string renderings below.
        return (self.seq_id, self.position, self.left_count, self.right_count)

    def __str__(self):
        return '%s:%d %f, %f' % self._as_tuple()

    def __repr__(self):
        return '%s\t%d\t%f\t%f' % self._as_tuple()

    @staticmethod
    def fromLine(line):
        '''Parse one tab-separated text line; return None if it has
        fewer than four columns.'''
        columns = line.strip().split('\t')
        if len(columns) < 4:
            return None
        seq_id, pos, left, right = columns[:4]
        return origin_count_line(seq_id, int(pos), float(left), float(right))


class Stranded_Read_Initiation_Probability_Vector(object):
    """
    Per-position probabilities of read initiation in the forward (plus) and
    backward (minus) directions along a DNA sequence.
    """

    def __init__(self, _plus, _minus):
        object.__init__(self)
        self.plus = _plus
        self.minus = _minus

    @classmethod
    def empty(cls):
        # Zero-length vector: no positions on either strand yet.
        return cls([], [])

    def __len__(self):
        # The usable length is bounded by the shorter strand.
        return min(len(self.plus), len(self.minus))

    def __repr__(self):
        return "Stranded_Read_Initiation_Probability_Vector(%s, %s)" % (
            concise_list_repr(self.plus), concise_list_repr(self.minus))

    def __str__(self):
        return repr(self)

    def __getitem__(self, item):
        # Index or slice both strands in lockstep, yielding a new vector.
        return Stranded_Read_Initiation_Probability_Vector(self.plus[item], self.minus[item])

    def extend(self, other):
        # Append the other vector's positions to this one, strand by strand.
        self.plus.extend(other.plus)
        self.minus.extend(other.minus)

    def scale(self, factor):
        # Multiply every probability on both strands by `factor`.
        self.plus = [factor * p for p in self.plus]
        self.minus = [factor * p for p in self.minus]

    def clone(self):
        # Copy the probability lists so the clone can be mutated independently.
        return Stranded_Read_Initiation_Probability_Vector(_plus=list(self.plus), _minus=list(self.minus))

    def switch_strand(self):
        # View the sequence from the opposite strand: swap the two vectors
        # and reverse each one in place.
        self.plus, self.minus = self.minus, self.plus
        self.plus.reverse()
        self.minus.reverse()


def concise_list_repr(alist):
    """Render a list of numbers compactly: two significant digits, comma-separated."""
    formatted = ['%.2g' % value for value in alist]
    return '[' + ','.join(formatted) + ']'


def get_origin_counts_for_exon(exon, ocl_iter):
    '''Fetch the origin-count records overlapping one exon.

    exon is a GFF3Exon instance; ocl_iter is a Tabixfile over a
    counted.origins file.  Returns a list of origin_count_line objects
    (an entry may be None when a fetched line has fewer than 4 columns).
    A ValueError -- from fetch() (e.g. a contig missing from the index)
    or from parsing a malformed numeric field -- is reported to stderr
    and yields an empty list.
    '''
    try:
        counts = [origin_count_line.fromLine(line) for line in
                  ocl_iter.fetch(exon.get_seqID(), exon.get_start(), exon.get_end())]
    except ValueError as ve:  # "as" form: valid in py2.6+ and py3, unlike "except ValueError, ve"
        sys.stderr.write('%s\n' % ve)  # same output as the old "print >> sys.stderr, ve"
        counts = []
    return counts


def get_exon_origin_counts_list(exon, origin_count_tabix, len_probs):
    '''Fetch one exon's origin counts and shift each record's position from
    reference coordinates into transcript coordinates (len_probs is the
    transcript length accumulated before this exon).'''
    shifted = []
    for record in get_origin_counts_for_exon(exon, origin_count_tabix):
        record.position += len_probs - exon.get_start()
        shifted.append(record)
    return shifted


def compensate_left_flank_counts(counts, avg_count):
    '''Compensate for selection bias against left-terminal fragments.

    Scans the first half of counts for the first position whose left count
    reaches avg_count and is not exceeded by its right count; that position
    marks the left flank.  Left counts at earlier positions are then scaled
    up linearly, reaching a factor of 2 at position 0.  Mutates and returns
    counts.
    '''
    if not counts:
        # nothing to compensate; also avoids indexing an empty list below
        return counts
    i = 0
    # // is explicit floor division: identical to py2's int "/", py3-safe
    for i in range(len(counts) // 2):
        if counts[i].left_count >= avg_count and counts[i].left_count >= counts[i].right_count:
            break
    # If no qualifying position was found, i is simply the last index scanned.
    left_flank = counts[i].position
    if left_flank > 0:
        slope = 1.0 / left_flank
        for j in range(i - 1, -1, -1):
            if counts[j].left_count > 0:
                factor = 1 + slope * (left_flank - counts[j].position)
                counts[j].left_count *= factor
    return counts


def compensate_right_flank_counts(counts, avg_count, len_probs):
    '''Compensate for selection bias against right-terminal fragments.

    Mirror image of compensate_left_flank_counts: scans the second half of
    counts from the right for the first position whose right count reaches
    avg_count and is not exceeded by its left count; that position marks the
    right flank.  Right counts at later positions are scaled up linearly,
    reaching a factor of 2 at the final position (len_probs - 1).  Mutates
    and returns counts.
    '''
    if not counts:
        # nothing to compensate; also avoids indexing an empty list below
        return counts
    i = len(counts) - 1
    # // is explicit floor division: identical to py2's int "/", py3-safe
    for i in range(len(counts) - 1, len(counts) // 2, -1):
        if counts[i].right_count >= avg_count and counts[i].left_count <= counts[i].right_count:
            break
    right_flank = counts[i].position
    if right_flank < len_probs - 1:
        slope = 1.0 / (len_probs - 1 - right_flank)
        # j starts at i itself; there factor == 1, so the flank is unchanged
        for j in range(i, len(counts)):
            if counts[j].right_count > 0:
                factor = 1 + slope * (counts[j].position - right_flank)
                counts[j].right_count *= factor
    return counts


def calc_probs_from_counts(counts, probs):
    '''Turn raw origin counts into per-position probabilities, in place.

    The entries already present in probs.minus / probs.plus act as
    pseudocounts.  Each strand is normalized so its final probabilities sum
    to 1; a strand whose total is zero is left untouched.  Returns probs.
    '''
    total_minus = sum(probs.minus) + sum(ocl.left_count for ocl in counts)
    total_plus = sum(probs.plus) + sum(ocl.right_count for ocl in counts)
    if total_minus > 0:
        # prescale the pseudocounts, then add each observation's share
        probs.minus = [pc / total_minus for pc in probs.minus]
        for ocl in counts:
            probs.minus[ocl.position] += ocl.left_count / total_minus
    if total_plus > 0:
        # prescale the pseudocounts, then add each observation's share
        probs.plus = [pc / total_plus for pc in probs.plus]
        for ocl in counts:
            probs.plus[ocl.position] += ocl.right_count / total_plus
    return probs


def get_probs_for_transcript(transcript, origin_count_tabix):
    '''Build the stranded read-initiation probability vector for one transcript.

    transcript is a GFF3mRNA instance; origin_count_tabix is a tabix-indexed
    file of origin_count_lines.  Counts from all exons are concatenated in
    transcript coordinates, a pseudocount of 1/length is laid down at every
    position, and the result is normalized.  Transcripts with no counts at
    all yield an empty vector.  For minus-strand transcripts the vector is
    swapped and reversed (presumably so it runs along the transcript's own
    orientation -- confirm with callers).
    '''
    probs = Stranded_Read_Initiation_Probability_Vector.empty()
    counts = []
    transcript_len = 0
    for exon in transcript.get_exons():
        counts.extend(get_exon_origin_counts_list(exon, origin_count_tabix, transcript_len))
        transcript_len += len(exon)
    if counts:
        pseudocount = 1.0 / transcript_len
        probs.minus = [pseudocount] * transcript_len
        probs.plus = [pseudocount] * transcript_len
        probs = calc_probs_from_counts(counts, probs)
    if transcript.get_strand() == '-':
        probs.switch_strand()
    return probs


def condense_probs(probs):
    '''Sparse encoding: keep only the positive entries as (index, value) pairs.'''
    sparse = []
    for index, value in enumerate(probs):
        if value > 0:
            sparse.append((index, value))
    return sparse


def expand_probs(tuple_list, target_length):
    '''Inverse of condense_probs: rebuild a dense list from (index, value)
    pairs, zero-padded to at least target_length (longer when the last
    pair's index demands it).'''
    if tuple_list:
        length = max(tuple_list[-1][0] + 1, target_length)
    else:
        length = target_length
    dense = [0] * length
    for index, value in tuple_list:
        dense[index] = value
    return dense


def do_calcTranscriptFragmentationProbabilities(gff3, counts_tabix, prob_vectors):
    '''Compute a probability vector for every transcript in the GFF3 stream
    and store each one in prob_vectors (a shelve-like mapping) keyed by
    transcript ID.'''
    for gene in GFF3Iterator(gff3).genes():
        for transcript in gene.get_transcripts():
            vector = get_probs_for_transcript(transcript, counts_tabix)
            prob_vectors[transcript.get_ID()] = vector


def calcTranscriptFragmentationProbabilities_main(argv=sys.argv):
    '''Entry point.

    argv: [script, transcripts.gff3, counted.origins tabix file, output shelf].
    Exits via usage() when fewer than three arguments are supplied.
    '''
    # Validate argument count up front.  The original wrapped the whole run
    # in "except IndexError", which silently misreported any IndexError
    # raised during processing as a usage error.
    if len(argv) < 4:
        usage()
    gff3 = open(argv[1])
    counts_tabix = Tabixfile(argv[2])
    prob_vectors = shelve.open(argv[3], 'c')
    try:
        do_calcTranscriptFragmentationProbabilities(gff3, counts_tabix, prob_vectors)
    finally:
        # Close all handles even on failure (the gff3 file and tabix handle
        # were previously leaked).
        prob_vectors.close()
        counts_tabix.close()
        gff3.close()



if __name__ == '__main__':
    calcTranscriptFragmentationProbabilities_main(sys.argv)
    # Version-neutral replacement for the py2-only statement
    # "print >> sys.stderr, sys.argv[0], 'done.'" -- byte-identical output.
    sys.stderr.write('%s done.\n' % sys.argv[0])

