##################################
#                                #
# Last modified 10/31/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set

def _gtf_attribute(attributes, key):
    """Return the quoted value of *key* from a GTF attribute column.

    The attribute is expected in the ``key "value";`` form.  Raises
    IndexError when *key* is not present in *attributes*.
    """
    return attributes.split(key + ' "')[1].split('";')[0]


def _read_exons(gtf_path):
    """Collect the exon records of the GTF file at *gtf_path*.

    Returns a dict mapping ``(chrom, start, stop, strand)`` to a list of
    ``(geneID, transcriptID, FPKM, FPKM_lo, FPKM_hi)`` tuples, one per
    'exon' line carrying those coordinates.  Every exon line must provide
    the gene_id, transcript_id, FPKM, conf_lo and conf_hi attributes
    (presumably Cufflinks-style output -- confirm against the producer).
    """
    exons = {}
    with open(gtf_path) as gtf:
        for line in gtf:
            # Skip comment lines and blank lines (e.g. a trailing empty
            # line), which would otherwise fail on the field lookup below.
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            if fields[2] != 'exon':
                continue
            attributes = fields[8]
            coords = (fields[0], int(fields[3]), int(fields[4]), fields[6])
            record = (
                _gtf_attribute(attributes, 'gene_id'),
                _gtf_attribute(attributes, 'transcript_id'),
                float(_gtf_attribute(attributes, 'FPKM')),
                float(_gtf_attribute(attributes, 'conf_lo')),
                float(_gtf_attribute(attributes, 'conf_hi')),
            )
            exons.setdefault(coords, []).append(record)
    return exons


def _format_exon(coords, records):
    """Build the tab-separated output line (no newline) for one exon.

    Gene and transcript IDs are de-duplicated, sorted and comma-joined;
    the FPKM, conf_lo and conf_hi values of all *records* are summed.
    """
    chrom, start, stop, strand = coords
    genes = sorted(set(record[0] for record in records))
    transcripts = sorted(set(record[1] for record in records))
    fpkm = sum((record[2] for record in records), 0.0)
    fpkm_lo = sum((record[3] for record in records), 0.0)
    fpkm_hi = sum((record[4] for record in records), 0.0)
    return '\t'.join([
        chrom, str(start), str(stop), strand,
        ','.join(genes), ','.join(transcripts),
        str(fpkm), str(fpkm_lo), str(fpkm_hi),
    ])


def run():
    """Write a per-exon FPKM table for the GTF file named on the command line.

    Command line: ``python <script> gtf outputfilename``.

    All 'exon' lines sharing the same (chromosome, start, stop, strand) are
    merged into a single output row listing the distinct gene and transcript
    IDs and the sums of their FPKM, conf_lo and conf_hi values.  Rows are
    written in sorted coordinate order below a '#'-prefixed header line.

    Exits with status 1 after printing the usage message when fewer than two
    arguments are supplied.
    """
    # Both the GTF path and the output path are required, so argv must hold
    # at least three entries (script name included).
    if len(sys.argv) < 3:
        print('usage: python %s gtf outputfilename' % sys.argv[0])
        sys.exit(1)

    gtf_path = sys.argv[1]
    outfilename = sys.argv[2]

    # The flag only triggers this notice; the conf_lo column is always
    # part of the output.
    if '-FPKM_lo' in sys.argv:
        print('will output conf_lo estimates')

    # Parse the input before opening the output file so that a failing parse
    # does not leave a truncated output file behind.
    exons = _read_exons(gtf_path)

    print('found %d exons' % len(exons))

    with open(outfilename, 'w') as outfile:
        outfile.write('#chr\tleft\tright\tstrand\tgeneID\ttranscriptID\tFPKM\tFPKM_lo\tFPKM_hi\n')
        for coords in sorted(exons):
            outfile.write(_format_exon(coords, exons[coords]) + '\n')
   
# Script entry point: run() is invoked unconditionally, so it also executes
# if this file is imported as a module.
run()
