##################################
#                                #
# Last modified 2017/06/29       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
import os
from sets import Set

def _iter_lines(filename):
    """Yield the lines of *filename*, decompressing on the fly when needed.

    Files ending in '.bz2' are streamed through ``bzip2 -cd`` and files ending
    in '.gz' through ``zcat``; anything else is opened directly.  The external
    tools are started with an argument list (no shell), so file names that
    contain spaces or shell metacharacters are handled correctly, and the
    file handle / child process is always released when iteration ends.
    """
    if filename.endswith('.bz2'):
        command = ['bzip2', '-cd', filename]
    elif filename.endswith('.gz'):
        command = ['zcat', filename]
    else:
        command = None

    if command is None:
        handle = open(filename)
        try:
            for line in handle:
                yield line
        finally:
            handle.close()
        return

    # Imported here because the module header does not import subprocess.
    import subprocess
    proc = subprocess.Popen(command, stdout=subprocess.PIPE,
                            universal_newlines=True)
    try:
        for line in proc.stdout:
            yield line
    finally:
        proc.stdout.close()
        status = proc.wait()
    if status != 0:
        sys.stderr.write("warning: '%s' exited with status %d; "
                         "input may be incomplete\n"
                         % (' '.join(command), status))


def _load_regions(regionfilename, chrFieldID, posFieldID, strandFieldID,
                  radius, doPlusOnly, doMinusOnly):
    """Read the region file and pre-allocate the score tables.

    Returns a tuple ``(regions, plus_scores, minus_scores)`` where *regions*
    is a set of ``(chrom, pos, strand)`` keys and the two score tables map
    ``chrom -> {position: 0}`` for every position in
    ``[pos - radius, pos + radius)`` of every kept region.

    Lines starting with '#' and lines with too few tab-separated fields are
    skipped.  Regions whose strand is neither '+' nor '-' are skipped with a
    warning, because they cannot be oriented as sense/antisense.
    """
    regions = set()
    plus_scores = {}
    minus_scores = {}
    # A line must contain every requested column (and, as before, at least 3).
    min_fields = max(3, chrFieldID + 1, posFieldID + 1, strandFieldID + 1)
    unstranded = 0
    line_count = 0
    for line in _iter_lines(regionfilename):
        if line.startswith('#'):
            continue
        fields = line.strip().split('\t')
        line_count += 1
        if len(fields) < min_fields:
            continue
        chrom = fields[chrFieldID]
        pos = int(fields[posFieldID])
        strand = fields[strandFieldID]
        if doPlusOnly and strand == '-':
            continue
        if doMinusOnly and strand == '+':
            continue
        if strand not in ('+', '-'):
            unstranded += 1
            continue
        regions.add((chrom, pos, strand))
        if chrom not in plus_scores:
            plus_scores[chrom] = {}
            minus_scores[chrom] = {}
        plus_positions = plus_scores[chrom]
        minus_positions = minus_scores[chrom]
        for i in range(pos - radius, pos + radius):
            plus_positions[i] = 0
            minus_positions[i] = 0
        if line_count % 10000 == 0:
            print(line_count)
    if unstranded:
        sys.stderr.write("warning: skipped %d region(s) whose strand is "
                         "neither '+' nor '-'\n" % unstranded)
    return regions, plus_scores, minus_scores


def _load_wig_scores(wigfilename, scores_by_chrom, variable_step, take_abs):
    """Fill *scores_by_chrom* in place with the signal from one wig file.

    Only positions already present in ``scores_by_chrom[chrom]`` are updated,
    so memory use is bounded by the region windows.

    variable_step -- if true, parse the file as variableStep wiggle
                     (declaration lines followed by "position value" lines);
                     otherwise parse four-column
                     "chrom start stop score" lines.
    take_abs      -- if true, store ``math.fabs(score)`` instead of the score.

    Positions are used exactly as they appear in the file; no conversion
    between coordinate conventions is performed.
    """
    chrom = None
    span = 1
    line_count = 0
    for line in _iter_lines(wigfilename):
        line_count += 1
        if line_count % 1000000 == 0:
            print(str(line_count // 1000000) + 'M lines processed')

        if variable_step:
            if line.startswith('variableStep'):
                header = line.strip()
                chrom = header.split('chrom=')[1].split(' ')[0]
                # span is optional in a declaration line and defaults to 1.
                if 'span=' in header:
                    span = int(header.split('span=')[1].split(' ')[0])
                else:
                    span = 1
                continue
            fields = line.split()
            # Skip blank lines, comments and track/browser header lines.
            if (not fields or fields[0] in ('track', 'browser')
                    or fields[0].startswith('#')):
                continue
            if chrom is None:
                raise ValueError('data line before any variableStep '
                                 'declaration in %s: %r'
                                 % (wigfilename, line))
            start = int(fields[0])
            stop = start + span
            score = float(fields[1])
        else:
            fields = line.split()
            try:
                start = int(fields[1])
            except (IndexError, ValueError):
                # Header / blank / otherwise non-data line.
                continue
            chrom = fields[0]
            stop = int(fields[2])
            score = float(fields[3])

        if take_abs:
            score = math.fabs(score)

        # Look the chromosome up once per line instead of once per base.
        positions = scores_by_chrom.get(chrom)
        if positions is None:
            continue
        for i in range(start, stop):
            if i in positions:
                positions[i] = score


def _aggregate_profiles(regions, plus_scores, minus_scores, radius):
    """Sum the per-base signal of all regions, oriented by strand.

    Returns ``(sense_total, antisense_total)``: two lists of length
    ``2 * radius`` where index ``k`` corresponds to offset ``k - radius``.
    For '+' regions the window is read left to right with the plus file as
    sense; for '-' regions the window is read right to left with the minus
    file as sense.  Regions are processed in sorted order so floating-point
    sums are reproducible.
    """
    sense_total = [0.0] * (2 * radius)
    antisense_total = [0.0] * (2 * radius)
    for chrom, pos, strand in sorted(regions):
        window = list(range(pos - radius, pos + radius))
        if strand == '+':
            sense_src = plus_scores[chrom]
            antisense_src = minus_scores[chrom]
        else:
            sense_src = minus_scores[chrom]
            antisense_src = plus_scores[chrom]
            window.reverse()
        for idx, position in enumerate(window):
            sense_total[idx] += sense_src[position]
            antisense_total[idx] += antisense_src[position]
    return sense_total, antisense_total


def run():
    """Command-line entry point.

    Reads its arguments from ``sys.argv``:

        inputfilename chrFieldID posField strandField radius
        pluswigfilename minuswigfilename outputfilename
        [-normalize] [-variableStep] [-plusOnly] [-minusOnly]

    For every region (chr, position, strand) in the input file, the signal
    in ``[position - radius, position + radius)`` is collected from the plus
    and minus wig files, oriented relative to the region's strand, and summed
    per offset over all regions.  The output file has one line per offset
    with the sense sum and the negated antisense sum; with ``-normalize`` the
    sums are divided by the number of regions.

    Exits with status 1 after printing usage if fewer than the eight
    required arguments are given.
    """
    # Eight positional arguments are required, i.e. sys.argv[1]..sys.argv[8].
    if len(sys.argv) < 9:
        print('usage: python %s inputfilename chrFieldID posField strandField radius pluswigfilename minuswigfilename outputfilename [-normalize] [-variableStep] [-plusOnly] [-minusOnly]' % sys.argv[0])
        print('\tInput format: <fields .. tabs> chr <tab> position <tab> strandField')
        print('\tThis script outputs the average signal over all regions within the given radius')
        print('\tThe input and the wig file can be zipped')
        sys.exit(1)

    regionfilename = sys.argv[1]
    chrFieldID = int(sys.argv[2])
    posFieldID = int(sys.argv[3])
    strandFieldID = int(sys.argv[4])
    radius = int(sys.argv[5])
    pluswigfilename = sys.argv[6]
    minuswigfilename = sys.argv[7]
    outfilename = sys.argv[8]

    doVariableStep = '-variableStep' in sys.argv

    print(outfilename)

    doNormalize = '-normalize' in sys.argv

    doPlusOnly = '-plusOnly' in sys.argv
    if doPlusOnly:
        print('will only look at regions in the plus orientation')

    doMinusOnly = '-minusOnly' in sys.argv
    if doMinusOnly:
        print('will only look at regions in the minus orientation')

    regions, plus_scores, minus_scores = _load_regions(
        regionfilename, chrFieldID, posFieldID, strandFieldID,
        radius, doPlusOnly, doMinusOnly)

    print('Finished importing regions')
    print('Importing wiggle scores')

    _load_wig_scores(pluswigfilename, plus_scores, doVariableStep, False)
    # Minus-strand scores are stored as absolute values in the four-column
    # format only.
    # NOTE(review): the variableStep path has never applied fabs() to
    # minus-strand scores; kept as-is to avoid changing results -- confirm
    # whether that asymmetry is intended.
    _load_wig_scores(minuswigfilename, minus_scores, doVariableStep,
                     not doVariableStep)

    print('Finished importing wiggle scores')
    print('Outputting final stats')

    sense_total, antisense_total = _aggregate_profiles(
        regions, plus_scores, minus_scores, radius)

    region_count = len(regions)
    if doNormalize and region_count == 0:
        sys.stderr.write('warning: no regions were loaded; '
                         'writing un-normalized (zero) profile\n')

    with open(outfilename, 'w') as outfile:
        outfile.write('#Pos\tSense\tAntiSense\n')
        for offset in range(-radius, radius):
            sense = sense_total[offset + radius]
            antisense = -antisense_total[offset + radius]
            if doNormalize and region_count:
                sense = sense / region_count
                antisense = antisense / region_count
            outfile.write(str(offset) + '\t' + str(sense) + '\t'
                          + str(antisense) + '\n')
   
# Only execute the command-line entry point when run as a script, so that
# importing this module does not parse sys.argv or call sys.exit().
if __name__ == '__main__':
    run()
