##################################
#                                #
# Last modified 01/31/2013       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import scipy.stats
from sets import Set
import time

def _data_fields(path, skip_prefixes=('#',)):
    """Yield the tab-separated fields of every data line in ``path``.

    Lines starting with any of ``skip_prefixes`` are skipped, as are blank
    lines (e.g. a trailing empty line), which previously crashed the parser.
    The file is opened with ``with`` so it is closed when iteration finishes.
    """
    with open(path) as handle:
        for line in handle:
            if line.startswith(skip_prefixes) or not line.strip():
                continue
            yield line.strip().split('\t')


def _oriented_ends(strand, left, right):
    """Return ``(splice5p, splice3p)`` for a junction spanning left..right.

    On '+' the 5' site is the left coordinate and on '-' it is the right one.
    Any other strand value returns None: the orientation is unknown, and the
    original code would have reused coordinates from a previous line.
    """
    if strand == '+':
        return left, right
    if strand == '-':
        return right, left
    return None


def _genomic_ends(strand, splice5p, splice3p):
    """Inverse of ``_oriented_ends``: return ``(left, right)`` for output."""
    if strand == '-':
        return splice3p, splice5p
    return splice5p, splice3p


def _record(site_dict, key, partner, counts, annotation):
    """Store a junction under splice-site ``key``, indexed by its other end."""
    site_dict.setdefault(key, {})[partner] = {'counts': counts,
                                              'annotation': annotation}


def _binom_pvalue(successes, trials, p=0.5):
    """Two-sided exact binomial test p-value, working across SciPy versions.

    Uses ``scipy.stats.binomtest`` when present (newer SciPy) and falls back
    to the legacy ``scipy.stats.binom_test`` (removed in SciPy 1.12).
    Both arguments are truncated to int; ``trials`` is passed in as a float
    read-count total holding a whole number.
    """
    successes = int(successes)
    trials = int(trials)
    binomtest = getattr(scipy.stats, 'binomtest', None)
    if binomtest is None:
        return scipy.stats.binom_test(successes, trials, p)
    if trials == 0:
        # binomtest rejects n == 0; the legacy binom_test yields 1.0 for 0 of 0.
        return 1.0
    return binomtest(successes, trials, p).pvalue


def _catch_rate_pvalue(copies, copies_limit, low, high, catch_rate):
    """Catch-rate p-value for a ``low``/``high`` split of ``copies`` molecules.

    Computes, over cca in [int(copies), copies_limit):
        NB(cca - int(copies); copies, catch_rate)
          * sum_{i=0..low} B(i) * sum_{j=high..int(cca/2)-1} B(j)
    where B is the pmf of Binomial(int(0.5 * cca), catch_rate). Zero-valued
    terms are skipped. Returns int 0 if nothing is accumulated.
    """
    base = int(copies)
    extra_copies = scipy.stats.nbinom(copies, catch_rate)  # loop-invariant
    total = 0
    for cca in range(base, copies_limit):
        weight = extra_copies.pmf(cca - base)
        if weight == 0:
            continue
        half = int(0.5 * cca)
        caught = scipy.stats.binom(half, catch_rate)
        # The j-terms do not depend on i, so evaluate them once per cca.
        # The accumulation order is unchanged, so the sum is bit-identical.
        upper = [caught.pmf(j) for j in range(high, max(high, half))]
        upper = [p2 for p2 in upper if p2 != 0]
        for i in range(low + 1):
            p1 = caught.pmf(i)
            if p1 == 0:
                continue
            for p2 in upper:
                total += weight * p1 * p2
    return total


def _format_row(site_key, partner, shared_is_5p, label, entry, n_sites,
                total, status1, status2, trailing):
    """Format one tab-separated output line (including the newline).

    ``site_key`` is ``(chrom, shared_site, strand)``; ``partner`` is the other
    end of this junction. Coordinates are written in genomic (left, right)
    order. ``trailing`` holds the last five columns (p-values / copies).
    """
    chrom, shared_site, strand = site_key
    if shared_is_5p:
        splice5p, splice3p = shared_site, partner
    else:
        splice5p, splice3p = partner, shared_site
    left, right = _genomic_ends(strand, splice5p, splice3p)
    counts = entry['counts']
    fraction = 0 if total == 0 else counts / total
    fields = [chrom, str(left), str(right), strand, label]
    fields.extend(entry['annotation'])
    fields.extend([str(counts), str(n_sites), str(total - counts),
                   str(fraction), status1, status2])
    fields.extend(str(value) for value in trailing)
    return '\t'.join(fields) + '\n'


def _report_group(outfile, site_key, partners, shared_is_5p, gene_transcripts,
                  copies_per_gene, pvalue_cache, settings):
    """Write one row per junction in a group that shares one splice site.

    ``partners`` maps the opposite end of every junction in the group to its
    ``{'counts': int, 'annotation': (geneID, geneName, transcriptID,
    transcriptName)}`` record. Rows are labelled with the end that varies
    ('3p' when the shared site is the 5' end, '5p' otherwise).

    If no junction carries at least half of the group's reads, every row gets
    'no_dominant_splice' placeholders. Otherwise the gene name is looked up in
    ``copies_per_gene``; if it is missing, a message goes to stdout and the
    group is skipped. Per-junction read, copy and catch-rate p-values are then
    written. ``pvalue_cache`` memoizes catch-rate p-values by (smaller, larger)
    copy split across groups; ``settings`` is ``(catch_rate,
    evaluation_factor, copies_limit_cap, p_value_cutoff)``.
    """
    catch_rate, evaluation_factor, copies_limit_cap, p_value_cutoff = settings
    label = '3p' if shared_is_5p else '5p'
    n_sites = len(partners)

    total = 0.0
    major = 0
    genes = set()
    transcripts = set()
    gene_name = None
    for entry in partners.values():
        counts = entry['counts']
        major = max(major, counts)
        total += counts
        gid, gname, tid, tname = entry['annotation']
        genes.add((gid, gname))
        transcripts.add((tid, tname))
        # The group's gene is that of the last junction visited; groups that
        # span several genes are flagged via status1 below.
        gene_name = gname

    status1 = 'multiple_genes' if len(genes) > 1 else 'OK'
    status2 = 'OK'
    if any(len(transcripts) != len(gene_transcripts[g]) for g in genes):
        status2 = 'not_in_all_models'

    if total == 0 or major / total < 0.5:
        placeholders = ['no_dominant_splice'] * 5
        for partner, entry in partners.items():
            outfile.write(_format_row(site_key, partner, shared_is_5p, label,
                                      entry, n_sites, total, status1, status2,
                                      placeholders))
        return

    try:
        copies = copies_per_gene[gene_name]
    except KeyError:
        sys.stdout.write('%s not found in expression file, skipping\n'
                         % gene_name)
        return

    for partner, entry in partners.items():
        counts = entry['counts']
        reads_p = _binom_pvalue(counts, total)
        splice_copies = int((counts / total) * copies)
        alt_copies = int(copies - splice_copies)
        copies_p = _binom_pvalue(splice_copies, alt_copies + splice_copies)
        split = (min(splice_copies, alt_copies), max(splice_copies, alt_copies))
        if copies_p >= p_value_cutoff or reads_p >= p_value_cutoff:
            catch_p = 'not_calculated'
        elif split in pvalue_cache:
            catch_p = pvalue_cache[split]
        else:
            copies_limit = min(int(evaluation_factor * copies), copies_limit_cap)
            catch_p = _catch_rate_pvalue(copies, copies_limit, split[0],
                                         split[1], catch_rate)
            pvalue_cache[split] = catch_p
        outfile.write(_format_row(site_key, partner, shared_is_5p, label,
                                  entry, n_sites, total, status1, status2,
                                  [reads_p, splice_copies, alt_copies,
                                   copies_p, catch_p]))


def run():
    """Command-line entry point (see the usage text for the 10 arguments).

    Reads copies-per-cell estimates, which are summed per gene. Columns
    ``geneFieldID`` and ``copies_per_cell_fieldID`` are 0-based indices, and
    the gene column is matched against annotation gene *names*. It then reads
    annotated junctions and detected junctions; the staggered counts of
    detected junctions are attached to annotated splice sites. A detected
    junction with at least one annotated end is recorded as 'novel' under the
    gene of that annotated end. For every splice site shared by two or more
    junctions, one row per junction is written to ``outfile``. Progress goes
    to stdout; prints usage and exits with status 1 if arguments are missing.
    """
    if len(sys.argv) < 11:
        usage = [
            'usage: python %s sample.juncs annotations.juncs copies_per_cell_file geneFieldID copies_per_cell_fieldID catch_rate evaluation_limit_factor copies_limit_cap p_value_cutoff outfile' % sys.argv[0],
            '\tassumed annotations.juncs format:',
            '\t#chr\tleft\tright\tstrand\tGeneID(s)\tGeneName(s)\tTranscriptID(s)\tTranscriptName(s)',
            '\tassumed sample.juncs format:',
            '\t#chr\tleft\tright\tstrand\ttotal_counts\tstaggered_counts',
            '\tcatch rate parameter should be a float, for example 0.10',
            '\tNote: multiple entries for the same gene in the copies_per_cell file will be summed',
            '\tevaluation_limit_factor refers to the number up to which catch rate p-value estimation should be carried out, which will be the evaluation_limit_factor times the number of copies per cell',
            '\tfor large number of copies, the calculation becomes impractical; the minimum of the copies_limit_cap or the evaluation_limit_factor times the number of copies per cell will be used',
            '\tuse the p-value cutoff in order to save computational time; genes for which either the read or copies p-values are above that cutoff will not be evaluated for catch rate effects',
        ]
        sys.stdout.write('\n'.join(usage) + '\n')
        sys.exit(1)

    juncs_path = sys.argv[1]
    annotated_path = sys.argv[2]
    copies_path = sys.argv[3]
    gene_field = int(sys.argv[4])
    copies_field = int(sys.argv[5])
    settings = (float(sys.argv[6]),   # catch_rate
                float(sys.argv[7]),   # evaluation_limit_factor
                int(sys.argv[8]),     # copies_limit_cap
                float(sys.argv[9]))   # p_value_cutoff
    outfilename = sys.argv[10]

    copies_per_gene = {}
    for fields in _data_fields(copies_path, ('#', 'tracking_id')):
        gene = fields[gene_field]
        copies_per_gene[gene] = (copies_per_gene.get(gene, 0)
                                 + float(fields[copies_field]))
    sys.stdout.write('finished inputting copies-per-cell estimates\n')

    # site5p[(chrom, splice5p, strand)][splice3p] -> {'counts', 'annotation'}
    # site3p[(chrom, splice3p, strand)][splice5p] -> {'counts', 'annotation'}
    site5p = {}
    site3p = {}
    # gene_transcripts[(geneID, geneName)][(transcriptID, transcriptName)]
    #   -> {'5p': {site_key: 1}, '3p': {site_key: 1}}
    gene_transcripts = {}

    for fields in _data_fields(annotated_path):
        chrom, strand = fields[0], fields[3]
        ends = _oriented_ends(strand, int(fields[1]), int(fields[2]))
        if ends is None:
            continue
        splice5p, splice3p = ends
        annotation = (fields[4], fields[5], fields[6], fields[7])
        key5 = (chrom, splice5p, strand)
        key3 = (chrom, splice3p, strand)
        _record(site5p, key5, splice3p, 0, annotation)
        _record(site3p, key3, splice5p, 0, annotation)
        models = gene_transcripts.setdefault(annotation[:2], {})
        model = models.setdefault(annotation[2:], {'5p': {}, '3p': {}})
        model['5p'][key5] = 1
        model['3p'][key3] = 1
    sys.stdout.write('finished importing annotated junctions\n')

    for fields in _data_fields(juncs_path):
        chrom, strand = fields[0], fields[3]
        ends = _oriented_ends(strand, int(fields[1]), int(fields[2]))
        if ends is None:
            continue
        splice5p, splice3p = ends
        staggered = int(fields[5])
        key5 = (chrom, splice5p, strand)
        key3 = (chrom, splice3p, strand)
        # Take the gene from whichever end is annotated (the 5' end first).
        # Previously a 3'-only match reused the gene of an earlier line.
        if key5 in site5p:
            anchor = site5p[key5]
        elif key3 in site3p:
            anchor = site3p[key3]
        else:
            continue  # neither splice site is annotated
        gene = anchor[next(iter(anchor))]['annotation'][:2]
        if key5 in site5p and key3 in site3p and splice3p in site5p[key5]:
            site5p[key5][splice3p]['counts'] = staggered
            site3p[key3][splice5p]['counts'] = staggered
        else:
            novel = gene + ('novel', 'novel')
            if key5 in site5p:
                _record(site5p, key5, splice3p, staggered, novel)
            if key3 in site3p:
                _record(site3p, key3, splice5p, staggered, novel)
    sys.stdout.write('finished importing detected junctions\n')

    header = '#chr\tleft\tright\tstrand\t5p_or_3p\tgeneID\tgeneName\ttranscriptID\ttranscriptName\tstaggered_counts\talternative_sites\taltternative_counts\tfraction_reads\tannotation_status_1\tannotation_status_2\treads_p-value\tcopies\talternative_copies\tcopies_p-value\tcatch_rate_p-value'
    pvalue_cache = {}
    groups_done = 0
    with open(outfilename, 'w') as outfile:
        outfile.write(header + '\n')
        # Groups sharing a 3' site first (rows labelled '5p'), then groups
        # sharing a 5' site (rows labelled '3p'); the counter spans both.
        for site_dict, shared_is_5p, tag in ((site3p, False, '3p'),
                                             (site5p, True, '5p')):
            for site_key in sorted(site_dict):
                partners = site_dict[site_key]
                if len(partners) == 1:
                    continue
                groups_done += 1
                sys.stdout.write('%s %d\n' % (tag, groups_done))
                _report_group(outfile, site_key, partners, shared_is_5p,
                              gene_transcripts, copies_per_gene, pvalue_cache,
                              settings)

# Only run the pipeline when executed as a script, not when imported.
if __name__ == '__main__':
    run()