##################################
#                                #
# Last modified 06/01/2010       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string

def _format_score(value):
    """Return *value* as text truncated (not rounded) to three decimals.

    Ordinary floats keep the rendering produced by ``str`` (for example
    ``12.5`` stays ``12.5`` and ``0.123456`` becomes ``0.123``).  Values that
    ``str`` renders in exponent notation (``1e-05``, ``1.5e+16``) are expanded
    to fixed-point first, because cutting an exponent string at the decimal
    point either fails or yields a meaningless number.  Text without a decimal
    point (``inf``, ``nan``) is returned unchanged.
    """
    text = str(value)
    if 'e' in text or 'E' in text:
        text = '%.20f' % value
    whole, dot, fraction = text.partition('.')
    if not dot:
        return text
    return whole + '.' + fraction[:3]


def _load_exons(gtf_path):
    """Read the exon records of a GTF file.

    Returns a tuple ``(exon_transcripts, coverage)`` where
    ``exon_transcripts`` maps ``(chrom, start, stop)`` to the list of
    transcript IDs sharing that exon, and ``coverage`` maps each chromosome to
    a ``{position: 0}`` dictionary covering every exonic position.
    Each newly encountered chromosome name is printed as progress output.
    """
    exon_transcripts = {}
    coverage = {}
    with open(gtf_path) as handle:
        for line in handle:
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            # Blank or truncated lines cannot describe an exon; skip them
            # instead of failing on a missing column.
            if len(fields) < 9 or fields[2] != 'exon':
                continue
            transcript_id = fields[8].split('transcript_id "')[1].split('";')[0]
            chrom = fields[0]
            start = int(fields[3])
            stop = int(fields[4])
            if chrom not in coverage:
                coverage[chrom] = {}
                print(chrom)
            exon_transcripts.setdefault((chrom, start, stop), []).append(transcript_id)
            chrom_coverage = coverage[chrom]
            # NOTE(review): GTF coordinates are normally 1-based and inclusive,
            # while range(start, stop) treats them as half-open; one base of
            # each exon may be left out.  Confirm the intended convention
            # before changing, since it affects every reported value.
            for position in range(start, stop):
                chrom_coverage[position] = 0
    return exon_transcripts, coverage


def _load_coverage(wig_path, coverage):
    """Fill *coverage* from a four-column wig/bedGraph file.

    Every position of every non-zero interval is stored in
    ``coverage[chrom][position]`` and its score is added to the running total,
    which is returned.  Lines starting with ``track`` are ignored, lines
    without a numeric fourth column are reported and skipped.  Chromosomes
    that are absent from the annotation are added to *coverage* on the fly so
    their signal still counts towards the total.
    """
    coverage_sum = 0.0
    seen_chromosomes = set()
    with open(wig_path) as handle:
        for line in handle:
            if line.startswith('track'):
                continue
            fields = line.strip().split('\t')
            try:
                score = float(fields[3])
            except (IndexError, ValueError):
                print('skipping,  %s' % line)
                continue
            if score == 0:
                continue
            chrom = fields[0]
            start = int(fields[1])
            stop = int(fields[2])
            if chrom not in seen_chromosomes:
                seen_chromosomes.add(chrom)
                print(chrom)
            chrom_coverage = coverage.setdefault(chrom, {})
            for position in range(start, stop):
                chrom_coverage[position] = score
                coverage_sum += score
    return coverage_sum


def run():
    """Compute per-exon RPKM (or RPM with ``-RPM``) from a GTF and a wig file.

    Command line (read from ``sys.argv``)::

        gtf_filename wig_filename readlength output_filename [-RPM]

    The number of reads is estimated as total coverage divided by the read
    length.  One tab-separated line per distinct exon is written to the output
    file: chromosome, start, stop, score and the comma-separated transcript
    IDs containing the exon.  Exits with status 1 when arguments are missing
    or the read length is not positive.
    """
    # Four positional arguments are required, so argv needs five entries.
    if len(sys.argv) < 5:
        print('usage: python %s gtf_filename wig_filename readlength output_filename [-RPM]' % sys.argv[0])
        sys.exit(1)

    gtf = sys.argv[1]
    wig = sys.argv[2]
    readlength = int(sys.argv[3])
    outfilename = sys.argv[4]
    if readlength <= 0:
        print('readlength must be a positive integer')
        sys.exit(1)

    do_rpm = '-RPM' in sys.argv

    exon_transcripts, coverage = _load_exons(gtf)
    print('Finished inputting annotation, found %d exons' % len(exon_transcripts))

    coverage_sum = _load_coverage(wig, coverage)
    read_number = coverage_sum / readlength
    normalize_by = read_number / 1000000
    print('Estimated read number =  %s' % read_number)
    if not normalize_by:
        # Without any signal there is nothing to normalise by; report zeros
        # rather than dividing by zero.
        sys.stderr.write('warning: no coverage found, all scores reported as 0\n')

    if do_rpm:
        header = '#chr\tstart\tstop\tRPM\tTranscripts\n'
    else:
        header = '#chr\tstart\tstop\tRPKM\tTranscripts\n'

    with open(outfilename, 'w') as outfile:
        outfile.write(header)
        for (chrom, start, stop) in sorted(exon_transcripts):
            # Zero-length exons have no defined density.
            if start == stop:
                continue
            chrom_coverage = coverage[chrom]
            score = 0.0
            if normalize_by:
                for position in range(start, stop):
                    score += chrom_coverage[position] / normalize_by
            value = score / readlength
            if not do_rpm:
                value = value / ((stop - start) / 1000.)
            columns = [
                chrom,
                str(start),
                str(stop),
                _format_score(value),
                ','.join(exon_transcripts[(chrom, start, stop)]),
            ]
            outfile.write('\t'.join(columns) + '\n')
   
# Script entry point: only start processing when the file is executed
# directly, so that importing the module does not parse sys.argv or exit.
if __name__ == '__main__':
    run()
