##################################
#                                #
# Last modified 2019/04/19       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
import numpy as np
import os
from sets import Set

def _log(message):
    """Write one status line to stdout (portable across Python 2 and 3)."""
    sys.stdout.write('%s\n' % (message,))


def _load_genome(fasta):
    """Read a FASTA file into {sequence name: upper-cased sequence}.

    The name is the text after '>' on each header line; every name is echoed
    to stdout as it is read. Lines appearing before the first header are
    ignored.
    """
    genome = {}
    name = None
    chunks = []
    with open(fasta) as handle:
        for line in handle:
            if line.startswith('>'):
                if name is not None:
                    genome[name] = ''.join(chunks).upper()
                name = line.strip().split('>')[1]
                _log(name)
                chunks = []
            else:
                chunks.append(line.strip())
    if name is not None:
        genome[name] = ''.join(chunks).upper()
    return genome


def _load_regions(bed, field_index):
    """Read regions from a tab-delimited file.

    Columns field_index, field_index+1 and field_index+2 hold the chromosome,
    left and right coordinate. Lines starting with '#' and blank lines are
    skipped.

    Returns (regions, coverage):
      regions  -- {chrom: set of (chrom, left, right)}
      coverage -- {chrom: {pos: None}} for every pos in range(left, right);
                  None means "no methylation call recorded yet".
    """
    regions = {}
    coverage = {}
    with open(bed) as handle:
        for count, line in enumerate(handle, 1):
            if count % 1000 == 0:
                _log('%d lines processed' % count)
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            chrom = fields[field_index]
            left = int(fields[field_index + 1])
            right = int(fields[field_index + 2])
            regions.setdefault(chrom, set()).add((chrom, left, right))
            positions = coverage.setdefault(chrom, {})
            for pos in range(left, right):
                positions[pos] = None
    return regions, coverage


def _open_coverage(path):
    """Open a possibly compressed text file for line iteration.

    Returns (stream, process). process is None for uncompressed input.
    Compressed input (.bz2, *gz, .zip) is streamed through the external
    decompressor, but without a shell, so paths containing spaces or shell
    metacharacters are passed through intact.
    """
    import subprocess
    if path.endswith('.bz2'):
        cmd = ['bzip2', '-cd', path]
    elif path.endswith('gz'):
        cmd = ['gunzip', '-c', path]
    elif path.endswith('.zip'):
        cmd = ['unzip', '-p', path]
    else:
        return open(path), None
    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, universal_newlines=True)
    return proc.stdout, proc


def _is_cpg(sequence, pos):
    """Return True if 1-based `pos` is the C or the G of a 'CG' in `sequence`."""
    return sequence[pos - 1:pos + 1] == 'CG' or sequence[pos - 2:pos] == 'CG'


def _read_coverage(path, coverage, genome):
    """Fill `coverage` in place with (methylated, unmethylated) count tuples.

    Each tab-separated line supplies the chromosome (column 1), the position
    (column 2), the methylated count (column 5) and the unmethylated count
    (column 6). Lines starting with '#' or 'track' and blank lines are
    skipped. Only positions already present in `coverage` are recorded; a
    later line for the same position overwrites an earlier one.

    If `genome` is not None, only positions that are the C or the G of a CpG
    are kept. Chromosomes absent from `genome` are treated as having no CpGs
    rather than raising KeyError.

    Exits with an error message if an external decompressor reports failure,
    so truncated input is not silently summarised.
    """
    stream, proc = _open_coverage(path)
    seen = 0
    try:
        for raw in stream:
            line = raw.strip()
            # Skip blank lines instead of stopping: a blank line mid-file must
            # not end the read early.
            if not line:
                continue
            seen += 1
            if seen % 10000000 == 0:
                _log('%dM lines processed' % (seen // 1000000))
            if line.startswith(('#', 'track')):
                continue
            fields = line.split('\t')
            chrom = fields[0]
            pos = int(fields[1])
            # NOTE(review): `pos` is used as 1-based here (the CpG test indexes
            # the genome at pos-1), while region positions are range(left,
            # right). If the region file is a true 0-based half-open BED, the
            # window is shifted one base left (it includes 1-based `left` and
            # excludes 1-based `right`). Confirm the input conventions.
            positions = coverage.get(chrom)
            if positions is None or pos not in positions:
                continue
            if genome is not None and not _is_cpg(genome.get(chrom, ''), pos):
                continue
            positions[pos] = (int(fields[4]), int(fields[5]))
    finally:
        stream.close()
        status = proc.wait() if proc is not None else 0
    if status != 0:
        sys.exit('error: decompressing %s failed (exit status %d)' % (path, status))


def _write_summary(outfilename, regions, coverage, min_cov):
    """Write one line per region with its pooled methylation level.

    The 4th column is methylated / (methylated + unmethylated), pooled over
    all recorded positions in the region. Despite the header name it is a
    fraction in [0, 1], not a percentage. The 5th column is the total read
    count (a float). Both columns are 'nan' when the total is zero or below
    `min_cov`; the zero check also prevents division by zero when
    min_cov <= 0.
    """
    with open(outfilename, 'w') as out:
        out.write('#chr\tleft\tright\tAverageMethylation%\tNumReadsPos\n')
        for chrom in sorted(regions):
            positions = coverage[chrom]
            for name, left, right in sorted(regions[chrom]):
                methylated = 0.0
                unmethylated = 0.0
                for pos in range(left, right):
                    call = positions[pos]
                    if call is not None:
                        methylated += call[0]
                        unmethylated += call[1]
                total = methylated + unmethylated
                prefix = name + '\t' + str(left) + '\t' + str(right)
                if total > 0 and total >= min_cov:
                    out.write(prefix + '\t' + str(methylated / total) + '\t' + str(total) + '\n')
                else:
                    out.write(prefix + '\t' + 'nan' + '\t' + 'nan' + '\n')


def run():
    """Command-line entry point: average methylation over regions.

    usage: python script.py bedfilename chrField bismark.cov[.gz|.bz2|.zip]
           minCov outputfilename [-CpGonly genome.fa]

    chrField is the 0-based column of the chromosome in the region file; the
    next two columns are left and right. minCov is the minimum pooled read
    count for a region to get a value rather than 'nan'. With -CpGonly, only
    calls at C or G positions of a CpG in genome.fa are counted.
    """
    if len(sys.argv) < 6:
        _log('usage: python %s bedfilename chrField bismark.cov.gz minCov outputfilename [-CpGonly genome.fa]' % sys.argv[0])
        _log('\tNote: this script will output the combined methylation average over all nucleotides in the regions')
        sys.exit(1)

    bed = sys.argv[1]
    fieldID = int(sys.argv[2])
    coverage_path = sys.argv[3]
    minCov = int(sys.argv[4])
    outfilename = sys.argv[5]

    genome = None
    if '-CpGonly' in sys.argv:
        _log('will only include Cs in CpG context')
        flag_index = sys.argv.index('-CpGonly')
        if flag_index + 1 >= len(sys.argv):
            sys.exit('error: -CpGonly requires a genome FASTA file argument')
        genome = _load_genome(sys.argv[flag_index + 1])
        _log('finished inputting genomic sequence')

    regions, coverage = _load_regions(bed, fieldID)
    _log('finished inputing bed file, processing methylation coverage file')

    _read_coverage(coverage_path, coverage, genome)
    _write_summary(outfilename, regions, coverage, minCov)
   
# Run only when executed as a script, so the module can be imported (e.g. for
# testing) without parsing sys.argv or touching any files.
if __name__ == '__main__':
    run()
