##################################
#                                #
# Last modified 01/02/2012       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from sets import Set

def _data_lines(path):
    """Yield every non-blank line of the file at *path*.

    The file is opened in a ``with`` block, so it is closed once the
    iteration finishes.
    """
    with open(path) as handle:
        for line in handle:
            if line.strip():
                yield line


def _load_read_numbers(path):
    """Return ``{sample label: total reads}`` from the read number table.

    Lines starting with '#' are skipped.  For every other line the total is
    the sum of columns 2-5 (Unique, UniqueSplices, Multi, MultiSplices
    according to the usage text).
    """
    totals = {}
    for line in _data_lines(path):
        if line.startswith('#'):
            continue
        fields = line.strip().split('\t')
        totals[fields[0]] = (int(fields[1]) + int(fields[2]) +
                             int(fields[3]) + int(fields[4]))
    return totals


def _load_transcript_lengths(path):
    """Return ``{transcript_id: summed exon length}`` from a GTF file.

    Only records whose feature column is 'exon' contribute.
    """
    lengths = {}
    for line in _data_lines(path):
        if line.startswith('#'):
            continue
        fields = line.strip().split('\t')
        if fields[2] != 'exon':
            continue
        transcript_id = fields[8].split('transcript_id "')[1].split('";')[0]
        # GTF start/end are 1-based and inclusive, hence the "+ 1".
        exon_length = int(fields[4]) - int(fields[3]) + 1
        lengths[transcript_id] = lengths.get(transcript_id, 0) + exon_length
    return lengths


def _load_junctions(path):
    """Parse the junction table.

    Returns a pair ``(genes, junction_counts)``:

    * ``genes[(gene_id, gene_name)]`` is a dict with two entries:
      ``'junctions'`` maps a junction to the transcript IDs listed for it,
      ``'transcripts'`` maps a transcript ID to the junctions it contains.
    * ``junction_counts[junction][sample label]`` is the count as a float.

    A junction is the tuple ``(chr, left, right, strand)``.  Junctions
    annotated with more than one gene ID are ignored.  The '#' header line
    supplies the sample labels (columns 9 onwards).

    Raises ValueError if a data line precedes the header line.
    """
    genes = {}
    junction_counts = {}
    labels = None
    for line in _data_lines(path):
        fields = line.strip().split('\t')
        if line.startswith('#'):
            labels = dict((i, fields[i]) for i in range(8, len(fields)))
            continue
        if labels is None:
            raise ValueError('%s: no "#" header line before the data' % path)
        junction = (fields[0], int(fields[1]), int(fields[2]), fields[3])
        gene_ids = fields[4].split(',')
        if len(gene_ids) > 1:
            continue
        gene_key = (gene_ids[0], fields[5].split(',')[0])
        if gene_key not in genes:
            genes[gene_key] = {'junctions': {}, 'transcripts': {}}
        gene = genes[gene_key]
        transcript_ids = fields[6].split(',')
        gene['junctions'][junction] = list(transcript_ids)
        for transcript_id in transcript_ids:
            gene['transcripts'].setdefault(transcript_id, []).append(junction)
        junction_counts[junction] = dict(
            (labels[i], float(fields[i])) for i in range(8, len(fields)))
    return genes, junction_counts


def _load_isoform_fpkm(path):
    """Return ``{transcript_id: {sample label: FPKM}}`` from the FPKM table.

    Lines starting with 'tracking_id' are skipped.  The '#' header line
    supplies the sample labels (columns 8 onwards).  Only the part of each
    value before the first comma is converted to a float.

    Raises ValueError if a data line precedes the header line.
    """
    expression = {}
    labels = None
    for line in _data_lines(path):
        if line.startswith('tracking_id'):
            continue
        fields = line.strip().split('\t')
        if line.startswith('#'):
            labels = dict((i, fields[i]) for i in range(7, len(fields)))
            continue
        if labels is None:
            raise ValueError('%s: no "#" header line before the data' % path)
        expression[fields[0]] = dict(
            (labels[i], float(fields[i].split(',')[0]))
            for i in range(7, len(fields)))
    return expression


def _junction_z_scores(observed, fpkm, read_number, transcript_length,
                       junction_span):
    """Return ``(expected, poisson_z, nb_z)`` for one transcript and sample.

    ``expected`` is ``(junction_span / transcript_length)`` times the
    fragment count implied by *fpkm*, *read_number* and *transcript_length*.
    When ``expected`` is zero the two scores are the string 'Inf' if any
    fragment was observed and the integer 0 otherwise.
    """
    # Float literals keep Python 2 from truncating these ratios to integers.
    fragments = ((read_number / 2000000.0) *
                 (transcript_length / 1000.0) * fpkm)
    p = junction_span / float(transcript_length)
    expected = p * fragments
    if expected == 0:
        if observed > 0:
            return expected, 'Inf', 'Inf'
        return expected, 0, 0
    # The standard deviations are only computed here, after the zero check,
    # so that p == 0 can never be used as a divisor.
    poisson_z = (observed - expected) / math.sqrt(expected)
    nb_z = (observed - expected) / math.sqrt(expected / p)
    return expected, poisson_z, nb_z


def run():
    """Score each transcript's unique splice junctions against its FPKM.

    Command-line arguments (all required, read from ``sys.argv``):
    juncs, isoforms_FPKM, GTF, readNumbersTable, readLength,
    minReadJunctionOverlap, outfilename.

    For every transcript that has an expression entry, one row per sample is
    written to *outfilename*.  A transcript "has unique junctions" when at
    least one of its junctions lists that transcript only.  For such
    transcripts the largest observed count over the unique junctions is
    compared with the expected count, and a Poisson and an "NB" Z-score are
    reported.  Transcripts without an expression entry are counted and the
    count is printed at the end.

    Exits with status 1 after printing the usage text when too few
    arguments are given.
    """
    # Seven arguments are required, so sys.argv must hold eight entries.
    if len(sys.argv) < 8:
        print('usage: python %s juncs isoforms_FPKM GTF readNumbersTable readLength minReadJunctionOverlap outfilename' % sys.argv[0])
        print('File format requirements: juncs:')
        print('\t#chr    left    right   strand  GeneID(s)       GeneName(s)     TranscriptID(s) TranscriptName(s) sample1 ... sampleN')
        print('\tchr1    12056   12178   +       ENSG00000223972.3       DDX11L1 TCONS_00000009,TCONS_00000002   TCONS_00000009,TCONS_00000002')
        print('File format requirements: isoforms_FPKM:')
        print('\tENST00000000233.5       -       -       ENSG00000004059.5       ARF5    -       chr7:127228398-127231759 sample1....sampleN')
        print('Read Number Table format')
        print('\tName\tUnique\tUniqueSplices\tMulti\tMultiSplices')
        sys.exit(1)

    juncs = sys.argv[1]
    fpkm_filename = sys.argv[2]
    gtf = sys.argv[3]
    readNumberTable = sys.argv[4]
    readLength = int(sys.argv[5])
    minReadJunctionOverlap = int(sys.argv[6])
    outfilename = sys.argv[7]

    read_numbers = _load_read_numbers(readNumberTable)
    print('finished importing read numbers')

    transcript_lengths = _load_transcript_lengths(gtf)
    print('finished importing GTF')
    print('finished calculating transcript lengths')

    genes, junction_counts = _load_junctions(juncs)
    print('finished importing junctions information')

    expression = _load_isoform_fpkm(fpkm_filename)
    print('finished importing expression values')

    junction_span = 2 * (readLength - minReadJunctionOverlap)
    # The placeholder is the literal text N, backslash, A.  It is kept
    # byte-for-byte so existing consumers of the output keep working.
    not_available = '\tN\\A' * 4

    rows_done = 0
    skipped = 0
    missing_length = 0
    with open(outfilename, 'w') as outfile:
        outfile.write('#ID\tgene_or_transcript\tGeneID\tGeneName\tHas_unique_junctions\tSample\tFPKM\tMaximum_Observed_Unique_Junction_Fragments\tNB_Expected_Unique_Junction_Fragments\tPoisson_Z-score\tNB_Z-score\n')
        for (geneID, geneName) in genes:
            gene = genes[(geneID, geneName)]
            for transcriptID in gene['transcripts']:
                if transcriptID not in expression:
                    skipped += 1
                    continue
                unique_junctions = [
                    junction for junction in gene['transcripts'][transcriptID]
                    if len(gene['junctions'][junction]) == 1]
                if unique_junctions and transcriptID not in transcript_lengths:
                    # No exon record for this transcript: the expectation
                    # cannot be computed, so leave it out and report it.
                    missing_length += 1
                    continue
                prefix = (transcriptID + '\ttranscript\t' + geneID + '\t' +
                          geneName)
                for label in expression[transcriptID]:
                    rows_done += 1
                    if rows_done % 10000 == 0:
                        print(rows_done)
                    fpkm = expression[transcriptID][label]
                    if not unique_junctions:
                        outfile.write(prefix + '\tNo\t' + label + '\t' +
                                      str(fpkm) + not_available + '\n')
                        continue
                    observed = 0
                    for junction in unique_junctions:
                        observed = max(observed,
                                       junction_counts[junction][label])
                    expected, poisson_z, nb_z = _junction_z_scores(
                        observed, fpkm, read_numbers[label],
                        transcript_lengths[transcriptID], junction_span)
                    outfile.write(prefix + '\tYes\t' + label + '\t' +
                                  str(fpkm) + '\t' + str(observed) + '\t' +
                                  str(expected) + '\t' + str(poisson_z) +
                                  '\t' + str(nb_z) + '\n')

    print(skipped)
    if missing_length:
        print('%d transcripts skipped: no exon records in the GTF' % missing_length)
        
# Run the analysis only when this file is executed as a script, so that
# importing the module does not parse sys.argv or exit the interpreter.
if __name__ == '__main__':
    run()

