##################################
#                                #
# Last modified 2017/08/22       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import os
import string
import math
from sets import Set

def _iter_lines(filename):
    """Yield the lines of ``filename`` one at a time.

    Files ending in ``.bz2`` or ``.gz`` are decompressed by the external
    ``bzip2`` / ``gzip`` programs; anything else is opened directly.  The
    decompressor is started from an argument list (no shell), so file names
    containing spaces or shell metacharacters are passed through verbatim,
    and the file or pipe is closed once iteration ends.
    """
    import subprocess

    if filename.endswith('.bz2'):
        cmd = ['bzip2', '-cd', filename]
    elif filename.endswith('.gz'):
        cmd = ['gzip', '-cd', filename]
    else:
        cmd = None

    if cmd is None:
        handle = open(filename)
        try:
            for line in handle:
                yield line
        finally:
            handle.close()
    else:
        # universal_newlines=True makes the pipe yield text on Python 3 too.
        proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                                universal_newlines=True)
        try:
            for line in proc.stdout:
                yield line
        finally:
            proc.stdout.close()
            proc.wait()


def run():
    """Build a per-region signal profile around anchor positions.

    Command line (read from ``sys.argv``)::

        inputfilename chrFieldID posField [strandField | -noStrand]
        upstream downstream wigfilename outputfilename
        [-strand +|-] [-average bp] [-window bp] [-sortby fieldID]
        [-fullRegionInfo] [-narrowPeak]

    Regions are read from the tab-delimited input file; lines starting with
    ``#`` or with fewer than three fields are skipped.  The signal file is
    read as ``chr start stop score`` records (space- or tab-separated); lines
    whose second field is not an integer are skipped.  One output row is
    written per distinct (chr, position, strand) with one column per relative
    offset from ``-upstream`` to ``downstream`` in steps of the window size.

    Strands ``+``/``F`` are treated as forward and ``-``/``R`` as reverse;
    regions with any other strand value are reported as all zeros.

    Exits with status 1 after printing a message when too few arguments are
    given, when ``-window`` is smaller than 1, or when ``-average`` is
    combined with a ``-window`` larger than 1.
    """
    # argv[1] .. argv[8] are all read below, so nine entries are required.
    if len(sys.argv) < 9:
        print('usage: python %s inputfilename chrFieldID posField [strandField | -noStrand] upstream downstream wigfilename outputfilename [-strand +|-] [-average bp] [-window bp] [-sortby fieldID] [-fullRegionInfo] [-narrowPeak]' % sys.argv[0])
        print('\tInput format: <fields .. tabs> chr <tab> position')
        print('\tthe wig file can be in .bz2 or .gz format')
        sys.exit(1)

    regionfilename = sys.argv[1]
    chrFieldID = int(sys.argv[2])
    posFieldID = int(sys.argv[3])
    noStrand = sys.argv[4] == '-noStrand'
    strandFieldID = None
    if not noStrand:
        strandFieldID = int(sys.argv[4])
    upstream = int(sys.argv[5])
    downstream = int(sys.argv[6])
    wigfilename = sys.argv[7]
    outfilename = sys.argv[8]

    doNP = '-narrowPeak' in sys.argv
    doFullRegionInfo = '-fullRegionInfo' in sys.argv

    doStrand = '-strand' in sys.argv
    WantedStrand = None
    if doStrand:
        WantedStrand = sys.argv[sys.argv.index('-strand') + 1]

    doSort = '-sortby' in sys.argv
    sortFieldID = None
    sortList = []
    if doSort:
        sortFieldID = int(sys.argv[sys.argv.index('-sortby') + 1])

    averageRadius = 0
    doAverage = False
    if '-average' in sys.argv:
        averageRadius = int(int(sys.argv[sys.argv.index('-average') + 1]) / 2.0)
        if averageRadius > 0:
            doAverage = True
            print('will average signal over %d bp' % (2 * averageRadius))
        else:
            # A radius of 0 would average over an empty window and divide by
            # zero; fall back to the unsmoothed signal instead.
            averageRadius = 0
            print('averaging width is below 2 bp, reporting unsmoothed signal')

    window = 1
    doWindow = False
    if '-window' in sys.argv:
        doWindow = True
        window = int(sys.argv[sys.argv.index('-window') + 1])
        if window < 1:
            print('-window must be a positive number of bp')
            sys.exit(1)
        print('will split into windows of size %d bp' % window)

    if doAverage and window > 1:
        # Averaging produces one value per bp while the output has one column
        # per window, so the two options cannot share an output table.
        print('-average and -window (> 1 bp) cannot be combined')
        sys.exit(1)

    forwardStrands = ('+', 'F')
    reverseStrands = ('-', 'R')

    RegionDict = {}   # (chr, pos, strand) -> stripped input line
    ScoreDict = {}    # chr -> {genomic position -> score}, tracked positions only
    for line in _iter_lines(regionfilename):
        if line.startswith('#'):
            continue
        fields = line.strip().split('\t')
        if len(fields) < 3:
            continue
        chrom = fields[chrFieldID]
        if doNP:
            # narrowPeak: anchor = chromStart (column 2) + summit offset (column 10)
            pos = int(fields[1]) + int(fields[9])
        else:
            pos = int(fields[posFieldID])
        if noStrand:
            strand = '+'
        else:
            strand = fields[strandFieldID]
        if doStrand and strand != WantedStrand:
            continue
        RegionDict[(chrom, pos, strand)] = line.strip()
        chromScores = ScoreDict.setdefault(chrom, {})
        span = None
        if strand in forwardStrands:
            # The last forward window may reach up to window-1 positions past
            # pos+downstream when the span is not a multiple of the window.
            span = (pos - upstream - averageRadius,
                    pos + downstream + averageRadius + window - 1)
        elif strand in reverseStrands:
            span = (pos - downstream - averageRadius,
                    pos + upstream + averageRadius)
        if span is not None:
            for i in range(span[0], span[1]):
                chromScores[i] = 0
        if doSort:
            sortList.append((float(fields[sortFieldID]), chrom, pos, strand))

    print('Finished importing regions')
    print('Importing wiggle scores')

    lineCount = 0
    for line in _iter_lines(wigfilename):
        lineCount += 1
        if lineCount % 1000000 == 0:
            print('%d lines processed' % lineCount)
        if ' ' in line:
            fields = line.strip().split(' ')
        else:
            fields = line.strip().split('\t')
        try:
            start = int(fields[1])
        except (ValueError, IndexError):
            # Track/declaration lines and blank lines have no integer start.
            continue
        stop = int(fields[2])
        score = float(fields[3])
        chromScores = ScoreDict.get(fields[0])
        if chromScores is None:
            continue
        for i in range(start, stop):
            if i in chromScores:
                chromScores[i] = score

    print('Finished importing wiggle scores')
    print('Outputting final stats')

    # Relative offsets that become the output columns.
    columns = list(range(0 - upstream, 0 + downstream + 1, window))

    if doSort:
        sortList.sort()
        sortList.reverse()
        keys = [(chrom, pos, strand) for (_, chrom, pos, strand) in sortList]
    else:
        keys = sorted(RegionDict)

    with open(outfilename, 'w') as outfile:
        if doFullRegionInfo:
            outline = '#chr\tleft\tright\tstrand'
        else:
            outline = '#'
        outline += ''.join('\t' + str(c) for c in columns)
        outfile.write(outline + '\n')

        for (chrom, pos, strand) in keys:
            FinalDict = dict.fromkeys(columns, 0)
            chromScores = ScoreDict[chrom]
            forward = strand in forwardStrands
            reverse = strand in reverseStrands

            if forward:
                first, last = pos - upstream, pos + downstream
            elif reverse:
                first, last = pos - downstream, pos + upstream
            else:
                # Unrecognised strand: nothing was tracked, row stays zero.
                first = last = pos

            if doAverage:
                for i in range(first, last):
                    total = 0
                    for j in range(i - averageRadius, i + averageRadius):
                        total += chromScores[j]
                    offset = i - pos if forward else pos - i
                    FinalDict[offset] += total / (2.0 * averageRadius)
            elif doWindow:
                # Each bin is (output column, genomic start of the window).
                # Reverse-strand bins are anchored on the output columns so
                # every key exists even when the span is not a multiple of
                # the window; for exact multiples this matches stepping from
                # pos-downstream.
                if forward:
                    bins = [(b, pos + b)
                            for b in range(0 - upstream, downstream, window)]
                elif reverse:
                    bins = [(b, pos - b)
                            for b in range(window - upstream, downstream + 1, window)]
                else:
                    bins = []
                for (b, start) in bins:
                    total = 0.0
                    for j in range(start, start + window):
                        total += chromScores[j]
                    FinalDict[b] += total / window
            else:
                for i in range(first, last):
                    offset = i - pos if forward else pos - i
                    FinalDict[offset] += chromScores[i]

            if doFullRegionInfo:
                # NOTE(review): left/right ignore the strand, so for reverse
                # regions they do not match the scanned interval - confirm
                # whether downstream consumers rely on this.
                outline = chrom + '\t' + str(pos - upstream) + '\t' + str(pos + downstream) + '\t' + strand
            else:
                outline = chrom + '|' + str(pos) + '|' + strand
            outline += ''.join('\t' + str(FinalDict[c]) for c in columns)
            outfile.write(outline + '\n')
   
# Script entry point: run() is called unconditionally at module level
# (there is no __main__ guard), so importing this file also executes it.
run()
