##################################
#                                #
# Last modified 2017/04/26       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import bisect
import contextlib
import math
import os
import string
import subprocess
import sys
from sets import Set

@contextlib.contextmanager
def _open_wig_lines(wigfilename):
    """Yield an iterable of text lines read from *wigfilename*.

    ``.bz2`` files are streamed through ``bzip2 -cd`` and ``.gz`` files
    through ``zcat``; any other file is opened directly. The file handle or
    child process is released when the ``with`` block exits.
    """
    if wigfilename.endswith('.bz2'):
        cmd = ['bzip2', '-cd', wigfilename]
    elif wigfilename.endswith('.gz'):
        cmd = ['zcat', wigfilename]
    else:
        with open(wigfilename) as handle:
            yield handle
        return
    # An argument list (no shell) keeps file names containing spaces or
    # shell metacharacters intact.
    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                            universal_newlines=True)
    try:
        yield proc.stdout
    finally:
        proc.stdout.close()
        status = proc.wait()
    if status != 0:
        print('warning: %s exited with status %d' % (cmd[0], status))


def _read_single_bp_wig(lines, scores):
    """Fill *scores* from a single-base-resolution wiggle stream.

    Recognised section headers are ``variableStep chrom=<chrom> ...`` and
    ``#bedGraph section <chrom>:...``. Every other non-blank line is read as
    ``position<whitespace>score`` for the current chromosome.

    Args:
        lines: iterable of text lines.
        scores: dict chrom -> {position: score}. Only positions already
            present as keys are overwritten.
    """
    chromScores = None
    j = 0
    for line in lines:
        j += 1
        if j % 10000000 == 0:
            print(str(j // 1000000) + 'M lines processed')
        if line.startswith('variableStep'):
            chrom = line.strip().split('chrom=')[1].split(' ')[0]
            chromScores = scores.get(chrom)
            continue
        if line.startswith('#bedGraph section'):
            chrom = line.strip().split(':')[0].split(' ')[2]
            chromScores = scores.get(chrom)
            continue
        # No header seen yet, or a chromosome with no regions: nothing to keep.
        if chromScores is None:
            continue
        fields = line.split()
        if not fields or fields[0].startswith('#') or fields[0] in ('track', 'browser'):
            continue
        pos = int(fields[0])
        score = float(fields[1])
        if pos in chromScores:
            chromScores[pos] = score


def _read_interval_wig(lines, scores):
    """Fill *scores* from a ``chrom start stop score`` (bedGraph-style) stream.

    Fields may be separated by tabs or runs of spaces. Lines whose start is
    not an integer (e.g. ``track`` headers) are skipped silently. Lines whose
    stop or score cannot be parsed are printed and skipped. Every tracked
    position in ``[start, stop)`` takes the line's score; later lines
    overwrite earlier ones.

    Args:
        lines: iterable of text lines.
        scores: dict chrom -> {position: score}. Only positions already
            present as keys are overwritten.
    """
    # Sorted tracked positions per chromosome, built lazily, so an interval
    # only visits the positions it overlaps instead of every base in it.
    sortedPositions = {}
    j = 0
    for line in lines:
        j += 1
        if j % 1000000 == 0:
            print(str(j // 1000000) + 'M lines processed')
        fields = line.split()
        try:
            start = int(fields[1])
        except (IndexError, ValueError):
            continue
        try:
            stop = int(fields[2])
            score = float(fields[3])
        except (IndexError, ValueError):
            # Report and skip: falling through would reuse the previous
            # line's stop/score for this interval.
            print(fields)
            continue
        chromScores = scores.get(fields[0])
        if chromScores is None:
            continue
        positions = sortedPositions.get(fields[0])
        if positions is None:
            positions = sortedPositions[fields[0]] = sorted(chromScores)
        lo = bisect.bisect_left(positions, start)
        hi = bisect.bisect_left(positions, stop)
        for i in positions[lo:hi]:
            chromScores[i] = score


def _strand_profile(regions, scores, radius):
    """Return per-offset sums of window scores over *regions*.

    Element ``k`` of the returned list is offset ``k - radius`` (offsets
    ``-radius`` .. ``radius - 1``). Windows of regions whose strand is
    ``R`` or ``-`` are reversed before summing.
    """
    profile = [0.0] * (2 * radius)
    for chrom, pos, strand in sorted(regions):
        chromScores = scores[chrom]
        window = [chromScores[i] for i in range(pos - radius, pos + radius)]
        if strand in ('R', '-'):
            window.reverse()
        for k, value in enumerate(window):
            profile[k] += value
    return profile


def run():
    """Average a wiggle signal around a set of anchor positions.

    Command-line entry point; all arguments come from ``sys.argv``::

        inputfilename chrFieldID posField strandField radius wigfilename
        outputfilename [-singlebpwig] [-normalize] [-bed] [-unstranded]
        [-ERANGE_hts] [-narrowPeak] [-first number]

    One anchor ``(chrom, pos, strand)`` is read from each non-``#`` line of
    ``inputfilename``; identical anchors are counted once. Scores from
    ``wigfilename`` are collected for every base in
    ``[pos - radius, pos + radius)``; bases without a score count as 0.
    Windows of anchors on strand ``-`` or ``R`` are reversed.
    ``outputfilename`` receives one ``offset<TAB>value`` line per offset from
    ``-radius`` to ``radius - 1``. ``value`` is the sum over anchors, or,
    with ``-normalize``, that sum divided by the number of anchors.

    Prints usage and exits with status 1 if fewer than seven positional
    arguments are given.
    """
    # Seven positional arguments follow the script name: argv[1]..argv[7].
    if len(sys.argv) < 8:
        print('usage: python %s inputfilename chrFieldID posField strandField radius wigfilename outputfilename [-singlebpwig] [-normalize] [-bed] [-unstranded] [-ERANGE_hts] [-narrowPeak] [-first number]' % sys.argv[0])
        print('\tInput format: <fields .. tabs> chr <tab> position <tab> strandField')
        print('\tThis script outputs the average signal over all regions within the given radius')
        print('\tthe wig file can be in .bz2 or .gz format')
        print('\tif the the -bed option is used, the middle point of a bed region will be used; specifiy the posField as the left coordinate of the region')
        print('\tif the the -narrowPeak option is used, the posFielf will be ignored and strand will be assumed to be +')
        sys.exit(1)

    regionfilename = sys.argv[1]
    chrFieldID = int(sys.argv[2])
    posFieldID = int(sys.argv[3])
    strandFieldID = int(sys.argv[4])
    radius = int(sys.argv[5])
    wigfilename = sys.argv[6]
    outfilename = sys.argv[7]

    doFirst = '-first' in sys.argv
    firstN = None
    if doFirst:
        firstN = int(sys.argv[sys.argv.index('-first') + 1])
        print('will only look at the first %d locations' % firstN)

    doNarrowPeak = '-narrowPeak' in sys.argv
    if doNarrowPeak:
        print('will treat regions as being in narrowPeak format')

    doBPWig = '-singlebpwig' in sys.argv

    doBed = '-bed' in sys.argv
    if doBed:
        print('will treat input as bed file and center around the midpoint of reigons')

    noStrand = '-unstranded' in sys.argv
    if noStrand:
        print('will treat all regions as + strand')

    doERANGE = '-ERANGE_hts' in sys.argv
    doNormalize = '-normalize' in sys.argv

    # Minimum number of tab-separated fields a region line needs for the
    # selected parsing mode (same precedence as the parser below).
    if doBed:
        usedFields = [chrFieldID, posFieldID + 1]
        if not noStrand:
            usedFields.append(strandFieldID)
        minFields = max(usedFields) + 1
    elif doERANGE or doNarrowPeak:
        minFields = 10  # both modes read column 9
    else:
        usedFields = [chrFieldID, posFieldID]
        if not noStrand:
            usedFields.append(strandFieldID)
        minFields = max(usedFields) + 1
    minFields = max(minFields, 3)

    # Distinct (chrom, pos, strand) anchors.
    regions = set()
    # chrom -> {position: score}, seeded with 0 for every base in a window so
    # that only those bases are retained while reading the wig file.
    scores = {}
    k = 0
    with open(regionfilename) as regionfile:
        for line in regionfile:
            if line.startswith('#'):
                continue
            fields = line.replace('\x00', '').strip().split('\t')
            if len(fields) < minFields:
                continue
            k += 1
            if doFirst and k > firstN:
                break
            if doBed:
                chrom = fields[chrFieldID]
                left = int(fields[posFieldID])
                right = int(fields[posFieldID + 1])
                pos = int((right + left) / 2.0)
                strand = '+' if noStrand else fields[strandFieldID]
            elif doERANGE:
                chrom = fields[1]
                pos = int(fields[9])
                strand = '+'
            elif doNarrowPeak:
                chrom = fields[0]
                # Column 1 plus the offset in column 9 (presumably the
                # narrowPeak summit offset from the start).
                pos = int(fields[1]) + int(fields[9])
                strand = '+'
            else:
                chrom = fields[chrFieldID]
                pos = int(fields[posFieldID])
                strand = '+' if noStrand else fields[strandFieldID]
            regions.add((chrom, pos, strand))
            chromScores = scores.setdefault(chrom, {})
            for i in range(pos - radius, pos + radius):
                chromScores[i] = 0
            if k % 10000 == 0:
                print(k)

    print('Finished importing regions %d regions in total' % len(regions))
    print('Importing wiggle scores')

    with _open_wig_lines(wigfilename) as lines:
        if doBPWig:
            _read_single_bp_wig(lines, scores)
        else:
            _read_interval_wig(lines, scores)

    print('Finished importing wiggle scores')
    print('Outputting final stats')

    profile = _strand_profile(regions, scores, radius)

    regionCount = len(regions)
    if doNormalize and regionCount == 0:
        print('no regions were imported; writing unnormalized sums')
    with open(outfilename, 'w') as outfile:
        for k, value in enumerate(profile):
            if doNormalize and regionCount:
                value = value / regionCount
            outfile.write(str(k - radius) + '\t' + str(value) + '\n')
   
# Run only when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    run()
