##################################
#                                #
# Last modified 2018/02/14       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
import numpy as np
import os
from sets import Set

def _load_regions(bed, fieldID):
    """Read the region file and index its intervals by chromosome.

    Each non-comment line is split on tabs; column ``fieldID`` is taken as the
    chromosome and the two following columns as integer ``left`` and ``right``
    coordinates.

    Returns a ``(regions, coverage)`` pair:

    * ``regions[chrom]`` is the set of distinct ``(left, right)`` intervals;
    * ``coverage[chrom]`` maps every position in ``range(left, right)`` of
      every interval to ``None`` (meaning "no score recorded yet").
    """
    regions = {}
    coverage = {}
    with open(bed) as handle:
        for lineno, line in enumerate(handle, 1):
            if lineno % 50000 == 0:
                print('%d lines processed' % lineno)
            # Blank lines have no fields and would otherwise raise IndexError.
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            chrom = fields[fieldID]
            left = int(fields[fieldID + 1])
            right = int(fields[fieldID + 2])
            regions.setdefault(chrom, set()).add((left, right))
            positions = coverage.setdefault(chrom, {})
            for pos in range(left, right):
                positions[pos] = None
    return regions, coverage


def _iter_input_lines(path):
    """Yield the text lines of ``path``, decompressing by file extension.

    ``.bz2``, ``.gz`` and ``.zip`` files are streamed through ``bzip2 -cd``,
    ``gunzip -c`` and ``unzip -p`` respectively; anything else is opened
    directly.  The external tool is started with an argument list (no shell),
    so spaces or shell metacharacters in ``path`` are passed through verbatim.
    """
    import subprocess

    if path.endswith('.bz2'):
        argv = ['bzip2', '-cd', path]
    elif path.endswith('.gz'):
        argv = ['gunzip', '-c', path]
    elif path.endswith('.zip'):
        argv = ['unzip', '-p', path]
    else:
        argv = None

    if argv is None:
        with open(path) as handle:
            for line in handle:
                yield line
        return

    proc = subprocess.Popen(argv, stdout=subprocess.PIPE,
                            universal_newlines=True)
    try:
        for line in proc.stdout:
            yield line
    finally:
        proc.stdout.close()
        status = proc.wait()
    if status != 0:
        # Warn instead of raising: some tools use non-zero codes for warnings.
        sys.stderr.write('warning: %s exited with status %d\n'
                         % (argv[0], status))


def _fill_coverage(path, minCov, coverage):
    """Record the score of every sufficiently covered position of interest.

    Lines are tab-separated; column 0 is the chromosome, column 1 the
    position, column 3 the score, and columns 4 and 5 are summed to give the
    read coverage.  Positions whose coverage is below ``minCov`` are skipped.
    ``coverage`` is updated in place, and only for positions already present
    in it; if a position occurs more than once the last score wins.

    NOTE(review): positions are matched against ``range(left, right)`` of the
    region file exactly as given.  If the regions are 0-based BED and the
    coverage file is 1-based (as Bismark documents), this is shifted by one;
    confirm the intended convention before changing it.
    """
    count = 0
    for raw in _iter_input_lines(path):
        line = raw.strip()
        # A blank line is skipped; it must not end processing of the stream.
        if not line:
            continue
        count += 1
        if count % 10000000 == 0:
            print('%d lines processed' % count)
        if line.startswith(('#', 'track')):
            continue
        fields = line.split('\t')
        chrom = fields[0]
        pos = int(fields[1])
        score = float(fields[3])
        depth = int(fields[4]) + int(fields[5])
        if depth < minCov:
            continue
        positions = coverage.get(chrom)
        if positions is not None and pos in positions:
            positions[pos] = score


def _write_averages(outfilename, regions, coverage):
    """Write one line per region with its mean score and position count.

    Regions are written sorted by chromosome, then by ``(left, right)``.
    The mean is divided by 100 before being written (the column header still
    carries a ``%`` sign; kept unchanged for compatibility with existing
    consumers).  Regions without any scored position get ``nan`` in both
    value columns.
    """
    with open(outfilename, 'w') as outfile:
        outfile.write('#chr\tleft\tright\tAverageMethylation%\tNumPos\n')
        for chrom in sorted(regions):
            positions = coverage[chrom]
            for left, right in sorted(regions[chrom]):
                meth = [positions[p] for p in range(left, right)
                        if positions[p] is not None]
                if meth:
                    stats = (str(np.mean(meth) / 100), str(len(meth)))
                else:
                    stats = ('nan', 'nan')
                columns = (chrom, str(left), str(right)) + stats
                outfile.write('\t'.join(columns) + '\n')


def run():
    """Command-line entry point.

    Reads five positional arguments from ``sys.argv``: the region (BED) file,
    the index of its chromosome column, the coverage file, the minimum
    coverage, and the output file name.  Prints the usage text and exits with
    status 1 when fewer than five arguments are supplied.
    """
    # Five arguments plus the program name: sys.argv[5] must exist.
    if len(sys.argv) < 6:
        print('usage: python %s bedfilename chrField bismark.cov.gz minCov outputfilename' % sys.argv[0])
        print('\tNote: this script will output the average methylation level for nucleotides with sufficient coverage')
        sys.exit(1)

    bed = sys.argv[1]
    fieldID = int(sys.argv[2])
    covfilename = sys.argv[3]
    minCov = int(sys.argv[4])
    outfilename = sys.argv[5]

    regions, coverage = _load_regions(bed, fieldID)

    print('finished inputing bed file, processing wig file')

    _fill_coverage(covfilename, minCov, coverage)
    _write_averages(outfilename, regions, coverage)
   
# Run only when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    run()
