##################################
#                                #
# Last modified 09/01/2010       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from commoncode import *

def _iter_bed_regions(bedfilename, chrField):
    """Yield (stripped_line, chrom, start, stop) for each data line of a BED file.

    Lines beginning with '#' or 'track' are skipped, as are blank lines (which
    would otherwise fail on the column lookups). The chromosome is read from
    column chrField (0-based) and start/stop from the two columns after it.
    The file is closed once the generator is exhausted.
    """
    with open(bedfilename) as bedfile:
        for line in bedfile:
            stripped = line.strip()
            if not stripped or line.startswith(('#', 'track')):
                continue
            fields = stripped.split('\t')
            yield (stripped, fields[chrField],
                   int(fields[chrField + 1]), int(fields[chrField + 2]))


def _accumulate_wig(wigfilename, coverageDict):
    """Add wig scores into coverageDict in place; return the total signal.

    Each data line is tab-separated chr/start/stop/score; 'track' lines and
    lines with fewer than four fields are skipped. Every base of every
    interval counts towards the returned total, whether or not the base lies
    inside an annotated region; only bases already present in coverageDict
    are accumulated.
    """
    total_score = 0.0
    with open(wigfilename) as wigfile:
        for i, line in enumerate(wigfile, 1):
            if i % 100000 == 0:
                print('%d lines processed' % i)
            if line.startswith('track'):
                continue
            fields = line.strip().split('\t')
            if len(fields) < 4:
                continue
            chrom = fields[0]
            start = int(fields[1])
            stop = int(fields[2])
            score = float(fields[3])
            # Equivalent to adding `score` once per base, without the loop.
            total_score += score * max(stop - start, 0)
            positions = coverageDict.get(chrom)
            if positions:
                for j in range(start, stop):
                    if j in positions:
                        positions[j] += score
    return total_score


def run():
    """Compute an RPKM value for every region of a BED file from a wig signal.

    Arguments are read from sys.argv:
        bedfilename     regions to score ('#'/'track'/blank lines skipped)
        chrFieldID      0-based column holding the chromosome; start and stop
                        are taken from the two following columns
        wigfilename     tab-separated chr/start/stop/score lines
        readlength      read length used to turn summed signal into reads
        outputfilename  destination file; each BED line is written back with
                        a tab and its RPKM value appended
        -rawSignal      optional; the wig holds raw counts, so region scores
                        are first divided by the total read count in millions

    Exits with status 1 and prints usage when too few arguments are given.
    """
    # Five positional arguments are required, so argv must have 6 entries.
    if len(sys.argv) < 6:
        print('usage: python %s bedfilename chrFieldID wigfilename readlength outputfilename [-rawSignal]' % sys.argv[0])
        print('\tuse the -rawSignal option if the input wig file is in absolute read counts and not RPM')
        sys.exit(1)

    bed = sys.argv[1]
    chrField = int(sys.argv[2])
    wig = sys.argv[3]
    readlength = int(sys.argv[4])
    outfilename = sys.argv[5]

    doRaw = '-rawSignal' in sys.argv
    if doRaw:
        print('will assume signal is not RPM-normalized')

    # chromosome -> {position: summed signal}; only bases inside a BED region
    # are tracked, so wig signal elsewhere only contributes to the total.
    coverageDict = {}
    for _, chrom, start, stop in _iter_bed_regions(bed, chrField):
        positions = coverageDict.setdefault(chrom, {})
        for j in range(start, stop):
            positions[j] = 0

    print('finished inputting annotation')
    print('processing wig')

    TotalScore = _accumulate_wig(wig, coverageDict)
    # Total signal / read length estimates the number of reads; expressed in
    # millions it is the -rawSignal normalization factor.
    normalizeBy = TotalScore / readlength / 1000000.0

    with open(outfilename, 'w') as outfile:
        outfile.write('#chr\tstart\tstop\tRPKM\n')
        for stripped, chrom, start, stop in _iter_bed_regions(bed, chrField):
            positions = coverageDict[chrom]
            score = float(sum(positions[j] for j in range(start, stop)))
            # With no signal at all every score is 0; avoid dividing by zero.
            if doRaw and normalizeBy > 0:
                score = score / normalizeBy
            score = score / readlength
            length = stop - start
            if length > 0:
                RPKM = score / (length / 1000.0)
            else:
                # Empty region: no per-kilobase rate can be computed.
                RPKM = 0
            outfile.write(stripped + '\t' + str(RPKM) + '\n')
   
# Only execute when invoked as a script, not when imported.
if __name__ == '__main__':
    run()
