##################################
#                                #
# Last modified 2025/05/22       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
import gzip
from sets import Set

def _openText(path):
    """Open *path* for line-by-line reading, going through gzip if needed.

    A file whose name ends in 'gz' is opened with the gzip module.  On
    Python 3 it is opened in text mode so that its lines are str, exactly
    like the lines of a plain file.
    """
    if path.endswith('gz'):
        if sys.version_info[0] >= 3:
            return gzip.open(path, 'rt')
        return gzip.open(path)
    return open(path)


def _loadRegions(bed, fieldID, doGTF, TheChr):
    """Read the regions to be scored and pre-allocate their coverage slots.

    bed     -- path of the region file (tab-delimited BED-like, or GTF).
    fieldID -- index of the chromosome column when reading a BED-like file;
               start and end are read from the two columns after it.
    doGTF   -- if true, read chromosome/start/end from columns 0/3/4 and
               collapse duplicate intervals before registering them.
    TheChr  -- chromosome to restrict to, or None for no restriction.

    Returns (regionDict, coverageDict): regionDict maps chromosome to a dict
    keyed by (chromosome, left, right); coverageDict maps chromosome to a
    dict holding a 0 for every position in range(left, right) of every
    region.
    """
    regionDict = {}
    coverageDict = {}
    with open(bed) as lineslist:
        if doGTF:
            # NOTE(review): start/end are used as-is in range(left, right);
            # GTF is conventionally 1-based inclusive -- confirm this is the
            # intended convention.  Rows are not filtered by feature type.
            exons = set()
            exonTotal = 0
            for l, line in enumerate(lineslist, 1):
                if l % 100000 == 0:
                    print('%d lines processed' % l)
                if line.startswith('#') or not line.strip():
                    continue
                fields = line.strip().split('\t')
                chrom = fields[0]
                if TheChr is not None and chrom != TheChr:
                    continue
                exonTotal += 1
                exons.add((chrom, int(fields[3]), int(fields[4])))
            print('total exons found %d' % exonTotal)
            print('collapsed exons set %d' % len(exons))
            for (chrom, left, right) in sorted(exons):
                if chrom not in regionDict:
                    print(chrom)
                    regionDict[chrom] = {}
                    coverageDict[chrom] = {}
                regionDict[chrom][(chrom, left, right)] = {}
                chromCov = coverageDict[chrom]
                for i in range(left, right):
                    chromCov[i] = 0
        else:
            for l, line in enumerate(lineslist, 1):
                if l % 50000 == 0:
                    print('%d lines processed' % l)
                if line.startswith('#') or not line.strip():
                    continue
                fields = line.strip().split('\t')
                chrom = fields[fieldID]
                if TheChr is not None and chrom != TheChr:
                    continue
                left = int(fields[fieldID + 1])
                right = int(fields[fieldID + 2])
                if chrom not in regionDict:
                    regionDict[chrom] = {}
                    coverageDict[chrom] = {}
                regionDict[chrom][(chrom, left, right)] = ''
                chromCov = coverageDict[chrom]
                for i in range(left, right):
                    chromCov[i] = 0
    return regionDict, coverageDict


def _readWig(wig, doBPWig, TheChr, coverageDict):
    """Copy track scores into coverageDict and return the total score.

    wig          -- path of the track file (gzip-compressed if it ends in 'gz').
    doBPWig      -- if true, parse 'variableStep chrom=...' headers followed
                    by 'position score' lines; otherwise parse tab-delimited
                    'chrom left right score' lines.
    TheChr       -- chromosome to restrict the coverage update to, or None.
    coverageDict -- chromosome -> {position: score}; only positions already
                    present are overwritten.

    The returned total sums the score of every base in the file, including
    chromosomes excluded by TheChr.
    """
    TotalScore = 0.0
    with _openText(wig) as linelist:
        if doBPWig:
            # Data lines seen before the first variableStep header have no
            # chromosome: they count towards the total but are not stored.
            chrom = None
            i = 0
            for line in linelist:
                if line.startswith(('#', 'track')) or not line.strip():
                    continue
                i += 1
                if i % 10000000 == 0:
                    print('%d lines processed' % i)
                if line.startswith('variableStep'):
                    chrom = line.strip().split('chrom=')[1].split()[0]
                    continue
                # Split on any whitespace so both tab- and space-separated
                # data lines are accepted.
                fields = line.split()
                pos = int(fields[0])
                score = float(fields[1])
                TotalScore += score
                if TheChr is not None and chrom != TheChr:
                    continue
                chromCov = coverageDict.get(chrom)
                if chromCov is not None and pos in chromCov:
                    chromCov[pos] = score
        else:
            for i, line in enumerate(linelist, 1):
                if i % 10000000 == 0:
                    print('%d lines processed' % i)
                if line.startswith(('#', 'track')) or not line.strip():
                    continue
                fields = line.strip().split('\t')
                chrom = fields[0]
                left = int(fields[1])
                right = int(fields[2])
                score = float(fields[3])
                # Each covered base contributes `score` once; an empty or
                # inverted interval contributes nothing.
                span = right - left
                if span > 0:
                    TotalScore += score * span
                if TheChr is not None and chrom != TheChr:
                    continue
                chromCov = coverageDict.get(chrom)
                if chromCov is None:
                    continue
                for j in range(left, right):
                    if j in chromCov:
                        chromCov[j] = score
    return TotalScore


def _writeRPKM(outfilename, regionDict, coverageDict, TotalScore, doRPMInput):
    """Write one 'chr<TAB>left<TAB>right<TAB>RPKM' line per region.

    Regions are written sorted by chromosome, then by (left, right).  A
    region with left >= right gets the literal value 0.  Otherwise the summed
    per-base score is divided by the region length in kb and, unless
    doRPMInput is set, also by TotalScore in millions.  If TotalScore is zero
    in that mode the value written is 0.0.
    """
    with open(outfilename, 'w') as outfile:
        outfile.write('#chr\tleft\tright\tRPKM\n')
        for chrom in sorted(regionDict):
            chromCov = coverageDict[chrom]
            for (_, left, right) in sorted(regionDict[chrom]):
                outline = chrom + '\t' + str(left) + '\t' + str(right)
                if left >= right:
                    outfile.write(outline + '\t0\n')
                    continue
                score = sum((chromCov[i] for i in range(left, right)), 0.0)
                kb = (right - left) / 1000.0
                if doRPMInput:
                    # NOTE(review): the summed per-base scores are divided
                    # by region length only -- confirm that no read-length
                    # normalisation is intended in this mode.
                    RPKM = score / kb
                elif TotalScore == 0:
                    # Nothing to normalise against; avoid dividing by zero.
                    RPKM = 0.0
                else:
                    RPKM = score / (kb * (TotalScore / 1000000.0))
                outfile.write(outline + '\t' + str(RPKM) + '\n')


def run():
    """Compute per-region RPKM from a coverage track; driven by sys.argv.

    Positional arguments: bedfilename chrField wigfilename readlength
    outputfilename.  Optional flags: -gtfexons (region file is a GTF),
    -singlebpwig (track is a variableStep wiggle), -chr chrN (restrict to one
    chromosome), -RPMInput (track scores are already RPM).

    readlength is only used to print an estimated read count
    (total score / readlength).  Exits with status 1 after printing usage if
    a positional argument is missing or -chr has no value.
    """
    if len(sys.argv) < 6:
        print('usage: python %s bedfilename chrField wigfilename readlength outputfilename [-gtfexons] [-singlebpwig] [-chr chrN] [-RPMInput]' % sys.argv[0])
        print('Note: use the -gtfexons option to run the script directly on the exons of a gtf file instead of a bed')
        print('Note: the script assumes non-normalized input, unless the -RPMInput is specified, in which case the wiggle track scores will be treated as RPMs')
        sys.exit(1)

    bed = sys.argv[1]
    fieldID = int(sys.argv[2])
    wig = sys.argv[3]
    readlength = int(sys.argv[4])
    outfilename = sys.argv[5]

    doBPWig = '-singlebpwig' in sys.argv

    doRPMInput = '-RPMInput' in sys.argv
    if doRPMInput:
        print('will treat input values as RPMs')

    doGTF = '-gtfexons' in sys.argv
    if doGTF:
        print('will treat input file as a gtf')

    TheChr = None
    if '-chr' in sys.argv:
        chrPos = sys.argv.index('-chr') + 1
        if chrPos >= len(sys.argv):
            print('the -chr option requires a chromosome name')
            sys.exit(1)
        TheChr = sys.argv[chrPos]
        print('will only look at %s' % TheChr)

    regionDict, coverageDict = _loadRegions(bed, fieldID, doGTF, TheChr)

    print(list(regionDict.keys()))
    print(list(coverageDict.keys()))

    print('finished inputing bed file, processing wig file')

    TotalScore = _readWig(wig, doBPWig, TheChr, coverageDict)

    ReadNumber = TotalScore / readlength

    print('estimated read number based on total sum of scores: %s' % ReadNumber)

    _writeRPKM(outfilename, regionDict, coverageDict, TotalScore, doRPMInput)
   
# Script entry point: run() is invoked unconditionally, so it executes both
# when the file is run as a script and when it is imported as a module.
run()
