##################################
#                                #
# Last modified 05/12/2012       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set

def _option_value(argv, flag):
    """Return the argument that follows `flag` in `argv`, or None if `flag` is absent."""
    if flag not in argv:
        return None
    return argv[argv.index(flag) + 1]


def _gtf_attribute(attributes, key):
    """Extract the value of a `key "value";` pair from a GTF attribute column.

    Raises IndexError when `key "` does not occur in `attributes`.
    """
    return attributes.split(key + ' "')[1].split('";')[0]


def _load_names(names_gtf, biotypes=None):
    """Read ID -> name mappings from the exon lines of an annotation GTF.

    Every exon line must carry gene_id, gene_name, transcript_id and
    transcript_name attributes (IndexError otherwise).

    When `biotypes` is given, only lines whose gene_type attribute is one of
    the listed values are used, and the gene/transcript IDs seen on those
    lines are collected.

    Returns a tuple (names, allowed_ids): `names` maps both transcript IDs
    and gene IDs to their names; `allowed_ids` is the set of IDs that passed
    the biotype filter, or None when no filter was requested.
    """
    names = {}
    allowed_ids = None
    wanted = None
    if biotypes is not None:
        wanted = set(biotypes)
        allowed_ids = set()
    with open(names_gtf) as handle:
        for line in handle:
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            if fields[2] != 'exon':
                continue
            attributes = fields[8]
            gene_id = _gtf_attribute(attributes, 'gene_id')
            gene_name = _gtf_attribute(attributes, 'gene_name')
            transcript_id = _gtf_attribute(attributes, 'transcript_id')
            transcript_name = _gtf_attribute(attributes, 'transcript_name')
            if wanted is not None:
                if _gtf_attribute(attributes, 'gene_type') not in wanted:
                    continue
                allowed_ids.add(transcript_id)
                allowed_ids.add(gene_id)
            names[transcript_id] = transcript_name
            names[gene_id] = gene_name
    return names, allowed_ids


def _load_transcripts(gtf, class_code=None, allowed_ids=None):
    """Collect the exons of every transcript in `gtf` that passes the filters.

    class_code 'u' keeps only exon lines whose class_code attribute is 'u'
    (every rejected line is echoed to stdout as "<transcript_id> <code>").
    class_code 'not_u' keeps lines whose class_code is missing, 'j' or '='.
    Any other value disables class-code filtering.  `allowed_ids`, when not
    None, restricts the result to the transcript IDs it contains.

    Returns {transcript_id: {'exons': [(chrom, start, stop, strand), ...]}}.
    """
    transcripts = {}
    with open(gtf) as handle:
        for line in handle:
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            if fields[2] != 'exon':
                continue
            chrom = fields[0]
            start = int(fields[3])
            stop = int(fields[4])
            strand = fields[6]
            transcript_id = _gtf_attribute(fields[8], 'transcript_id')
            if class_code is not None:
                try:
                    code = _gtf_attribute(fields[8], 'class_code')
                except IndexError:
                    # No class_code attribute on this line.
                    code = ''
                if class_code == 'u' and code != 'u':
                    print('%s %s' % (transcript_id, code))
                    continue
                if class_code == 'not_u' and code not in ('', 'j', '='):
                    continue
            if allowed_ids is not None and transcript_id not in allowed_ids:
                continue
            if transcript_id not in transcripts:
                transcripts[transcript_id] = {'exons': []}
            transcripts[transcript_id]['exons'].append((chrom, start, stop, strand))
    return transcripts


def _load_expression(expr, transcripts):
    """Attach (status, FPKM, FPKM_lo, FPKM_hi) from `expr` to known transcripts.

    `expr` is a tab-separated table whose header line starts with
    'tracking_id' and names the columns FPKM, FPKM_conf_lo, FPKM_conf_hi and
    either FPKM_status or status.  Rows whose first column is not a key of
    `transcripts` are ignored.  `transcripts` is updated in place: matching
    entries gain an 'expr' key.

    Raises ValueError if a matching data row precedes the header line or a
    required column is missing.
    """
    columns = None
    with open(expr) as handle:
        for line in handle:
            fields = line.strip().split('\t')
            if line.startswith('tracking_id'):
                try:
                    status_col = fields.index('FPKM_status')
                except ValueError:
                    status_col = fields.index('status')
                columns = (status_col,
                           fields.index('FPKM'),
                           fields.index('FPKM_conf_lo'),
                           fields.index('FPKM_conf_hi'))
                continue
            transcript_id = fields[0]
            if transcript_id not in transcripts:
                continue
            if columns is None:
                raise ValueError('%s: data row found before the tracking_id header line' % expr)
            status_col, fpkm_col, lo_col, hi_col = columns
            transcripts[transcript_id]['expr'] = (fields[status_col],
                                                  float(fields[fpkm_col]),
                                                  float(fields[lo_col]),
                                                  float(fields[hi_col]))


def _group_by_tss(transcripts):
    """Sum the expression of all transcripts that share a start site.

    The start site is the smallest exon start for '+' transcripts and the
    end of the last exon (after sorting) for '-' transcripts.  Transcripts
    without expression values, or whose strand is neither '+' nor '-', are
    skipped.  A site's status is that of its first transcript, overridden by
    'FAIL' as soon as any of its transcripts is 'FAIL'.

    Returns {(chrom, position, strand): {'transcripts': [...], 'status': str,
    'FPKM': float, 'FPKM_lo': float, 'FPKM_hi': float}}.
    """
    tss_dict = {}
    # Iterate in sorted order so that float sums and the recorded status do
    # not depend on dictionary ordering.
    for transcript_id in sorted(transcripts):
        record = transcripts[transcript_id]
        if 'expr' not in record:
            continue
        exons = sorted(record['exons'])
        chrom = exons[0][0]
        strand = exons[0][3]
        if strand == '+':
            tss = (chrom, exons[0][1], strand)
        elif strand == '-':
            tss = (chrom, exons[-1][2], strand)
        else:
            # Unstranded or unrecognised strand: no start site can be defined.
            continue
        status, fpkm, fpkm_lo, fpkm_hi = record['expr']
        if tss not in tss_dict:
            tss_dict[tss] = {'transcripts': [], 'status': status,
                             'FPKM': 0, 'FPKM_lo': 0, 'FPKM_hi': 0}
        elif status == 'FAIL':
            tss_dict[tss]['status'] = 'FAIL'
        entry = tss_dict[tss]
        entry['transcripts'].append(transcript_id)
        entry['FPKM'] += fpkm
        entry['FPKM_lo'] += fpkm_lo
        entry['FPKM_hi'] += fpkm_hi
    return tss_dict


def _write_tss_table(outfilename, tss_dict, names=None):
    """Write one tab-separated line per start site, sorted by (chrom, position, strand).

    Each of the three FPKM columns is written as "<value>,<status>".  When
    `names` is given, transcript IDs are replaced by their names; IDs that
    have no entry in `names` are written unchanged.
    """
    with open(outfilename, 'w') as outfile:
        outfile.write('#chr\tTSS\tstrand\ttranscripts\tFPKM\tFPKM_lo\tFPKM_hi\n')
        for key in sorted(tss_dict):
            chrom, position, strand = key
            entry = tss_dict[key]
            labels = sorted(entry['transcripts'])
            if names is not None:
                labels = [names.get(label, label) for label in labels]
            status = entry['status']
            columns = [chrom, str(position), strand, ','.join(labels)]
            for value_key in ('FPKM', 'FPKM_lo', 'FPKM_hi'):
                columns.append(str(entry[value_key]) + ',' + status)
            outfile.write('\t'.join(columns) + '\n')


def run():
    """Command-line entry point: collapse transcript FPKMs onto start sites.

    Reads its arguments from sys.argv:
        argv[1]  GTF whose exon lines carry transcript_id (and optionally
                 class_code) attributes
        argv[2]  isoforms.fpkm_tracking-style expression table
        argv[3]  output file name
        -class_code u|not_u   optional class-code filter
        -names GTF            optional annotation used to print names
        -biotypes a,b,...     optional gene_type filter (only with -names)

    Prints a usage message and exits with status 1 when fewer than three
    positional arguments are supplied.  The output file is only created once
    all inputs have been parsed successfully.
    """
    # Three positional arguments are required, so argv needs at least 4 items.
    if len(sys.argv) < 4:
        print('usage: python %s gtf_with_class_codes isoforms.fpkm_tracking outputfilename [-class_code u | not_u] [-names GTF] [-biotypes biotype1(,biotype2,...biotypeN)]' % sys.argv[0])
        print('     The biotypes option is only active when the -names GTF option is used and it looks for the gene_type')
        sys.exit(1)

    gtf = sys.argv[1]
    expr = sys.argv[2]
    outfilename = sys.argv[3]

    class_code = _option_value(sys.argv, '-class_code')
    if class_code is not None:
        print('class_code %s' % class_code)

    names = None
    allowed_ids = None
    names_gtf = _option_value(sys.argv, '-names')
    if names_gtf is not None:
        biotypes = _option_value(sys.argv, '-biotypes')
        if biotypes is not None:
            biotypes = biotypes.split(',')
        names, allowed_ids = _load_names(names_gtf, biotypes)

    transcripts = _load_transcripts(gtf, class_code, allowed_ids)
    _load_expression(expr, transcripts)
    tss_dict = _group_by_tss(transcripts)
    _write_tss_table(outfilename, tss_dict, names)
   
# Script entry point. The call is unconditional (no __main__ guard), so
# importing this module also executes run() against the current sys.argv.
run()
