##################################
#                                #
# Last modified 01/02/2012       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from sets import Set

def _load_junctions(path):
    """Parse the junction table.

    The file must contain a '#'-prefixed header line; its columns from index 8
    onward name the samples. Every other non-blank line describes one junction:
    chr, left, right, strand, GeneID(s), GeneName(s), TranscriptID(s),
    TranscriptName(s), then one count per sample. Rows annotated to more than
    one gene (comma-separated GeneIDs) are skipped.

    Returns (gene_dict, junction_counts) where, with junction being the tuple
    (chr, left, right, strand):
      gene_dict[(geneID, geneName)]['junctions'][junction] -> list of transcript IDs
      gene_dict[(geneID, geneName)]['transcripts'][transcriptID] -> list of junctions
      junction_counts[junction][sample_label] -> float count

    Raises ValueError if a data line appears before the header line.
    """
    gene_dict = {}
    junction_counts = {}
    labels = None  # column index -> sample label, filled from the '#' header
    with open(path) as infile:
        for line in infile:
            if not line.strip():
                continue  # tolerate blank lines, e.g. a trailing newline at EOF
            fields = line.strip().split('\t')
            if line[0] == '#':
                labels = dict((i, fields[i]) for i in range(8, len(fields)))
                continue
            if labels is None:
                raise ValueError('%s: data line found before the "#" header line' % path)
            junction = (fields[0], int(fields[1]), int(fields[2]), fields[3])
            gene_ids = fields[4].split(',')
            if len(gene_ids) > 1:
                continue  # ambiguous gene assignment: ignore this junction
            gene_key = (gene_ids[0], fields[5].split(',')[0])
            gene = gene_dict.setdefault(gene_key, {'junctions': {}, 'transcripts': {}})
            transcript_ids = fields[6].split(',')
            gene['junctions'][junction] = list(transcript_ids)
            for transcript_id in transcript_ids:
                gene['transcripts'].setdefault(transcript_id, []).append(junction)
            junction_counts[junction] = dict(
                (labels[i], float(fields[i])) for i in range(8, len(fields)))
    return gene_dict, junction_counts


def _load_expression(path):
    """Parse the isoform FPKM table into {transcriptID: {sample_label: FPKM}}.

    Lines starting with 'tracking_id' and blank lines are ignored. A
    '#'-prefixed header line names the samples from column 7 onward. Each
    sample cell may hold comma-separated values; only the first is used.

    Raises ValueError if a data line appears before the header line.
    """
    expression = {}
    labels = None  # column index -> sample label, filled from the '#' header
    with open(path) as infile:
        for line in infile:
            if line.startswith('tracking_id') or not line.strip():
                continue
            fields = line.strip().split('\t')
            if line[0] == '#':
                labels = dict((i, fields[i]) for i in range(7, len(fields)))
                continue
            if labels is None:
                raise ValueError('%s: data line found before the "#" header line' % path)
            expression[fields[0]] = dict(
                (labels[i], float(fields[i].split(',')[0]))
                for i in range(7, len(fields)))
    return expression


def _classify_transcript(unique_junctions, transcript_fpkm, junction_counts,
                         fpkm_threshold, junction_threshold):
    """Return (passing_in_all, passing_in_at_least_one), each 'Yes'/'No'/'N/A'.

    Only samples where the transcript's FPKM >= fpkm_threshold are considered.
    In such a sample the transcript is supported when at least one of its
    unique junctions has a count >= junction_threshold. Both values are 'N/A'
    when no sample passes the FPKM threshold. `unique_junctions` must be
    non-empty.
    """
    # All junction rows share the sample labels of the junction-file header,
    # so any unique junction yields the full label set.
    labels = junction_counts[unique_junctions[0]].keys()
    expressed = 0
    supported = 0
    for label in labels:
        if transcript_fpkm[label] >= fpkm_threshold:
            expressed += 1
            if any(junction_counts[j][label] >= junction_threshold
                   for j in unique_junctions):
                supported += 1
    if expressed == 0:
        return 'N/A', 'N/A'
    if supported == expressed:
        return 'Yes', 'Yes'
    if supported == 0:
        return 'No', 'No'
    return 'No', 'Yes'


def run():
    """Command-line entry point: score transcripts by unique-junction support.

    Usage: python script.py juncs isoforms_FPKM junction_threshold FPKM_threshold outfilename

    A junction is "unique" to a transcript when it is listed for exactly one
    transcript. For every transcript present in the expression file, one
    tab-separated row is written: ID, 'transcript', GeneID, GeneName,
    Has_unique_junctions, Passing_in_all_samples,
    Passing_in_at_least_one_samples. Transcripts missing from the expression
    file are reported on stdout and skipped. Exits with status 1 (after
    printing usage) when fewer than five arguments are given.
    """
    # Five positional arguments are required, so argv must have six entries.
    if len(sys.argv) < 6:
        print('usage: python %s juncs isoforms_FPKM junction_threshold FPKM_threshold outfilename' % sys.argv[0])
        print('File format requirements: juncs:')
        print('#chr    left    right   strand  GeneID(s)       GeneName(s)     TranscriptID(s) TranscriptName(s) sample1 ... sampleN')
        print('chr1    12056   12178   +       ENSG00000223972.3       DDX11L1 TCONS_00000009,TCONS_00000002   TCONS_00000009,TCONS_00000002')
        print('File format requirements: isoforms_FPKM:')
        print('ENST00000000233.5       -       -       ENSG00000004059.5       ARF5    -       chr7:127228398-127231759 sample1....sampleN')
        sys.exit(1)

    juncs = sys.argv[1]
    FPKM = sys.argv[2]
    junction_threshold = float(sys.argv[3])
    FPKM_threshold = float(sys.argv[4])
    outfilename = sys.argv[5]

    gene_dict, junction_counts = _load_junctions(juncs)
    print('finished importing junctions information')

    expression = _load_expression(FPKM)
    print('finished importing expression values')

    skipped = 0
    with open(outfilename, 'w') as outfile:
        outfile.write('#ID\tgene_or_transcript\tGeneID\tGeneName\tHas_unique_junctions\tPassing_in_all_samples\tPassing_in_at_least_one_samples\n')
        # NOTE: an earlier version sketched a gene-level summary row but never
        # enabled it; only transcript rows are written.
        for (geneID, geneName), gene in gene_dict.items():
            junction_transcripts = gene['junctions']
            for transcriptID, transcript_junctions in gene['transcripts'].items():
                if transcriptID not in expression:
                    skipped += 1
                    print('not found in expression file, skipping %d %s' % (skipped, transcriptID))
                    continue
                unique_junctions = [j for j in transcript_junctions
                                    if len(junction_transcripts[j]) == 1]
                if unique_junctions:
                    has_unique = 'Yes'
                    pass_all, pass_once = _classify_transcript(
                        unique_junctions, expression[transcriptID],
                        junction_counts, FPKM_threshold, junction_threshold)
                else:
                    has_unique, pass_all, pass_once = 'No', 'No', 'No'
                outfile.write('\t'.join([transcriptID, 'transcript', geneID, geneName,
                                         has_unique, pass_all, pass_once]) + '\n')
        
# Only run when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    run()

