##################################
#                                #
# Last modified 2017/02/22       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math

def _offset_header(label, radius, window):
    """Return the header row: label followed by one column per window offset."""
    columns = [label]
    columns.extend(str(offset) for offset in range(0 - radius, radius, window))
    return '\t'.join(columns) + '\n'


def _parse_region(line, chrFieldID):
    """Split one regions line into (fields, chromosome, left, right)."""
    fields = line.strip().split('\t')
    return (fields, fields[chrFieldID],
            int(fields[chrFieldID + 1]), int(fields[chrFieldID + 2]))


def _init_coverage(regions, chrFieldID, radius, window):
    """Build {chromosome: {position: 0}} for every position the output may read.

    Each region is padded by radius on both sides.  The right side is padded
    by one extra window because the last averaging window starts before
    middle + radius but may end up to window - 1 positions beyond it.
    """
    overhang = max(window, 0)
    coverage = {}
    with open(regions) as handle:
        for line in handle:
            if line.startswith('#') or not line.strip():
                continue
            _, chrom, left, right = _parse_region(line, chrFieldID)
            positions = coverage.setdefault(chrom, {})
            for pos in range(left - radius, right + radius + overhang):
                positions[pos] = 0
    return coverage


def _load_track_scores(track, coverage):
    """Copy scores from a fixedStep wig stream into the covered positions.

    track is a file name, or '-' to read standard input.  Only the chrom= and
    start= fields of each fixedStep line are read; every data line advances
    the position by one, so the step= value is not consulted.
    """
    handle = sys.stdin if track == '-' else open(track)
    try:
        chrom_scores = None
        pos = None
        for count, line in enumerate(handle, 1):
            if count % 10000000 == 0:
                print(str(count // 1000000) + 'M lines processed')
            if line.startswith('fixedStep'):
                tokens = line.split(' ')
                chrom_scores = coverage.get(tokens[1].split('=')[1])
                pos = int(tokens[2].split('=')[1])
                continue
            score = float(line)
            if pos is None:
                raise ValueError('score line %d precedes any fixedStep line' % count)
            if chrom_scores is not None and pos in chrom_scores:
                chrom_scores[pos] = score
            pos += 1
    finally:
        # Never close the interpreter's stdin; only files opened here.
        if handle is not sys.stdin:
            handle.close()


def _conserved_center(scores, left, right, minScores):
    """Return the midpoint between the outermost conserved positions.

    minScores must be sorted in descending order.  The threshold is the
    highest entry reached by any score in [left, right), or 0 when none is
    reached (or the region is empty).  The midpoint lies between the first
    and the last position of [left, right) whose score meets the threshold;
    if no position qualifies, left and right themselves are used.
    """
    threshold = 0
    if left < right:
        peak = max(scores[pos] for pos in range(left, right))
        for candidate in minScores:
            if peak >= candidate:
                threshold = candidate
                break
    first = left
    last = right
    for pos in range(left, right):
        if scores[pos] >= threshold:
            first = pos
            break
    # Scan the same half-open interval from the right: right-1 down to left.
    for pos in range(right - 1, left - 1, -1):
        if scores[pos] >= threshold:
            last = pos
            break
    return int((first + last) / 2.)


def run():
    """Write windowed conservation profiles around each region's conserved core.

    Command-line arguments (sys.argv):
        1  regions file (tab-separated; lines starting with '#' are headers)
        2  zero-based index of the chromosome field; the next two fields are
           read as left and right coordinates
        3  fixedStep wig file of scores, or '-' for standard input
        4  comma-separated minimum scores
        5  radius around the chosen center
        6  window size used for averaging
        7  output file name
        optional: -InputScoreFile outfilename scoreFieldID

    For every region one row is written to the output file: the label
    chr:left-right followed by the mean score of each window between
    center - radius and center + radius.  With -InputScoreFile a second
    table is written that holds the value of field scoreFieldID (or 0) per
    window.  Exits with status 1 after printing usage when arguments are
    missing.
    """
    # Seven positional arguments are required, so argv needs 8 entries.
    if len(sys.argv) < 8:
        print('usage: python %s regions chrFieldID phastConsWig minScores(comma-separated) radius window outputfilename [-InputScoreFile outfilename scoreFieldID]' % sys.argv[0])
        print('\tNote: use - for stdIn for the wig file')
        print('\tNote: the script will center the region on the middle of the section flanked by nucleotides with the highest conservation score among the minScore list')
        sys.exit(1)

    regions = sys.argv[1]
    chrFieldID = int(sys.argv[2])
    track = sys.argv[3]
    minScores = sorted((float(m) for m in sys.argv[4].split(',')), reverse=True)
    radius = int(sys.argv[5])
    window = int(sys.argv[6])
    outfilename = sys.argv[7]

    doISF = '-InputScoreFile' in sys.argv
    if doISF:
        flag = sys.argv.index('-InputScoreFile')
        isf_path = sys.argv[flag + 1]
        ISFFeldID = int(sys.argv[flag + 2])

    coverage = _init_coverage(regions, chrFieldID, radius, window)
    print('finished parsing regions')

    _load_track_scores(track, coverage)
    print('finished parsing track')

    ISFoutfile = None
    outfile = open(outfilename, 'w')
    try:
        if doISF:
            ISFoutfile = open(isf_path, 'w')
            ISFoutfile.write(_offset_header('#ID', radius, window))
        outfile.write(_offset_header('ID', radius, window))

        with open(regions) as handle:
            for line in handle:
                if line.startswith('#'):
                    outfile.write(line.strip() + '\tconservation_score\tfraction_conserved' + '\n')
                    continue
                if not line.strip():
                    continue
                fields, chrom, left, right = _parse_region(line, chrFieldID)
                scores = coverage[chrom]
                middle = _conserved_center(scores, left, right, minScores)
                label = chrom + ':' + str(left) + '-' + str(right)

                row = [label]
                for start in range(middle - radius, middle + radius, window):
                    total = sum(scores[pos] for pos in range(start, start + window))
                    # float() keeps the mean a true division on Python 2 too.
                    row.append(str(total / float(window)))
                outfile.write('\t'.join(row) + '\n')

                if doISF:
                    ISF = float(fields[ISFFeldID])
                    regionradius = (right - left) / 2.
                    row = [label]
                    for start in range(middle - radius, middle + radius, window):
                        # NOTE(review): the upper bound uses right - regionradius
                        # (the region midpoint); middle + regionradius may have
                        # been intended - confirm before changing.
                        if start > (middle - regionradius) and start <= (right - regionradius):
                            row.append(str(ISF))
                        else:
                            row.append(str(0))
                    ISFoutfile.write('\t'.join(row) + '\n')
    finally:
        outfile.close()
        if ISFoutfile is not None:
            ISFoutfile.close()
   
# Script entry point: this call executes at module load (there is no
# __main__ guard), so importing this file also runs the whole pipeline.
run()
