##################################
#                                #
# Last modified 07/30/2013       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
import os

def _iter_regions(path, fieldID):
    """Yield (raw_line, chrom, start, stop) for every data line of a region file.

    Lines starting with '#' and blank lines are skipped.  The chromosome name
    is read from tab-separated column `fieldID`; start and stop are read from
    the two columns that follow it and converted to int.
    """
    with open(path) as handle:
        for line in handle:
            if line.startswith('#') or not line.strip():
                continue
            fields = line.split('\n')[0].split('\t')
            yield (line, fields[fieldID],
                   int(fields[fieldID + 1]), int(fields[fieldID + 2]))


def _iter_score_lines(path):
    """Yield the text lines of a score file, decompressing it when needed.

    '.bz2' and '.gz' files are streamed through the external `bzip2` /
    `gunzip` programs, as before.  The command is passed as an argument list
    (no shell), so paths containing spaces or shell metacharacters are safe.
    Plain files are opened directly instead of being piped through `cat`.
    """
    # Local import: this is the only place that needs subprocess.
    import subprocess

    if path.endswith('.bz2'):
        cmd = ['bzip2', '-cd', path]
    elif path.endswith('.gz'):
        cmd = ['gunzip', '-c', path]
    else:
        with open(path) as handle:
            for line in handle:
                yield line
        return

    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                            universal_newlines=True)
    try:
        for line in proc.stdout:
            yield line
    finally:
        # Always release the pipe and reap the child, even if the consumer
        # stops early or raises.
        proc.stdout.close()
        status = proc.wait()
    if status != 0:
        print('warning: "%s" exited with status %d' % (' '.join(cmd), status))


def _fill_mean_scores(score_path, regions):
    """Replace each value in `regions` with the mean score over its interval.

    `regions` maps (start, stop) tuples to a score and is updated in place.
    Positions inside a region that never appear in the score file count as 0.

    Raises ValueError if a score line is met before any 'fixedStep' header.
    """
    # Only positions covered by a region are kept, so memory scales with the
    # total region length rather than with the chromosome length.
    wanted = {}
    for (start, stop) in regions:
        for pos in range(start, stop):
            wanted[pos] = 0

    # Position state is local to this file; nothing leaks in from a
    # previously processed chromosome.
    current_pos = None
    for line in _iter_score_lines(score_path):
        if not line.strip():
            continue
        if line[0] == 'f':
            # Declaration line, e.g. "fixedStep chrom=chr1 start=10 step=1".
            # NOTE(review): the UCSC wiggle spec defines `start` as 1-based,
            # while BED coordinates are 0-based; positions are compared as-is
            # here (unchanged behaviour) -- confirm the region file's
            # convention.  `step` is not read; one position per line is
            # assumed.
            current_pos = int(line.split('start=')[1].split(' ')[0])
            continue
        if current_pos is None:
            raise ValueError('score line found before any fixedStep header '
                             'in %s' % score_path)
        if current_pos in wanted:
            wanted[current_pos] = float(line.strip())
        current_pos += 1

    for (start, stop) in list(regions):
        length = stop - start
        if length <= 0:
            # Empty or inverted interval: nothing to average, and dividing
            # would raise ZeroDivisionError.
            regions[(start, stop)] = 0
            continue
        total = 0
        for pos in range(start, stop):
            total += wanted[pos]
        regions[(start, stop)] = total / length


def run():
    """Append the mean phastCons score of each region to a region file.

    Command line:
        regionfilename chrFieldID PhastConsDirectory outputfilename

    `chrFieldID` is the 0-based index of the tab-separated column holding the
    chromosome name; start and stop are taken from the next two columns.  For
    every chromosome, files in `PhastConsDirectory` whose name up to the
    first '.' equals the chromosome name are considered (the last match
    wins).  Every non-comment input line is written to `outputfilename` with
    the score appended as an extra tab-separated column.

    Exits with status 1 after printing the usage line when fewer than four
    arguments are supplied.
    """
    # Four arguments are read below, so argv needs at least five entries.
    if len(sys.argv) < 5:
        print('usage: python %s regionfilename chrFieldID PhastConsDirectory outputfilename' % sys.argv[0])
        sys.exit(1)

    bedfile = sys.argv[1]
    fieldID = int(sys.argv[2])
    PhastConsDirectory = sys.argv[3]
    outfilename = sys.argv[4]

    # chrom -> {(start, stop): score}; scores start at 0.
    regionDict = {}
    for (_line, chrom, start, stop) in _iter_regions(bedfile, fieldID):
        regionDict.setdefault(chrom, {})[(start, stop)] = 0

    files = os.listdir(PhastConsDirectory)
    for chrom in regionDict:
        print('processing %s' % chrom)
        # Reset per chromosome so a missing file can never silently reuse the
        # file selected for the previous chromosome.
        ConsFile = None
        for phastConsfile in files:
            if phastConsfile.split('.')[0] == chrom:
                print('opened file %s' % phastConsfile)
                ConsFile = PhastConsDirectory + '/' + phastConsfile
        if ConsFile is None:
            print('warning: no phastCons file found for %s; its regions keep score 0' % chrom)
            continue
        _fill_mean_scores(ConsFile, regionDict[chrom])

    # Second pass over the input so the output keeps the original line order
    # (and repeats duplicated regions).
    with open(outfilename, 'w') as outfile:
        for (line, chrom, start, stop) in _iter_regions(bedfile, fieldID):
            outline = line.strip() + '\t' + str(regionDict[chrom][(start, stop)])
            outfile.write(outline + '\n')
   
# Script entry point: run() is invoked unconditionally, so it also executes
# when this file is imported as a module (there is no __main__ guard).
run()
