##################################
#                                #
# Last modified 2018/10/03       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import numpy as np
from scipy.stats import entropy
from scipy.stats import fisher_exact
from scipy.stats import beta
from scipy.stats import binom
from sklearn.metrics import normalized_mutual_info_score as NMIS
import random
import os
import math
from sets import Set

def _open_text_stream(path):
    """Return a readable pipe producing the (decompressed) text of *path*.

    The decompressor is chosen from the file extension: bzip2 for '.bz2',
    zcat for '.gz'/'.bgz', unzip for '.zip', and plain cat otherwise.
    The caller is responsible for closing the returned pipe.
    """
    if path.endswith('.bz2'):
        cmd = 'bzip2 -cd ' + path
    elif path.endswith(('.gz', '.bgz')):
        cmd = 'zcat ' + path
    elif path.endswith('.zip'):
        cmd = 'unzip -p ' + path
    else:
        cmd = 'cat ' + path
    return os.popen(cmd, "r")


def _read_methylation_entropy(positions, probabilities, left, right):
    """Return the binary entropy of one read's calls inside [left, right).

    Each probability whose position falls in the half-open interval is
    rounded to 0 or 1; the fraction of 1s (rounded to 5 decimals) and its
    complement are passed to scipy.stats.entropy.

    NOTE(review): a read with no positions inside the interval yields a
    fraction of 0 and therefore an entropy of 0, which still counts toward
    the region's coverage and mean -- confirm this is intended.
    """
    calls = [int(round(float(prob)))
             for pos, prob in zip(positions, probabilities)
             if left <= int(pos) < right]
    n_calls = len(calls)
    # Summing per-call fractions (rather than dividing the total once) keeps
    # the floating-point result identical to the historical computation.
    methylated = round(sum(float(call) / n_calls for call in calls), 5)
    return entropy([methylated, 1 - methylated])


def _region_read_entropies(tabix, reads, chrom, left, right):
    """Return the entropies of all reads that fully span chrom:left-right.

    Reads are fetched by running *tabix* on the indexed *reads* file.  The
    code reads column 1 and 2 of each record as the read's left/right
    coordinates and columns 6 and 7 as comma-separated positions and
    per-position values (presumably Tombo methylation probabilities, as the
    usage text states -- verify against the file producer).
    """
    cmd = tabix + ' ' + reads + ' ' + chrom + ':' + str(left) + '-' + str(right)
    pipe = os.popen(cmd, "r")
    entropies = []
    try:
        for raw in pipe:
            record = raw.strip()
            if record == '':
                # Skip blank lines instead of abandoning the remaining reads.
                continue
            rfields = record.split('\t')
            read_left = int(rfields[1])
            read_right = int(rfields[2])
            # Only reads covering the entire region are informative.
            if read_left > left or read_right < right:
                continue
            entropies.append(_read_methylation_entropy(
                rfields[6].split(','), rfields[7].split(','), left, right))
    finally:
        # Always reap the tabix child process and release its descriptor.
        pipe.close()
    return entropies


def run():
    """Annotate each region with the mean per-read methylation entropy.

    Command line (all positional arguments are required):
        methylation_reads_all.tsv region.bed chrFieldID leftFieldID
        rightFieldID minCoverage tabix_location outfileprefix
    Optional: -expectedMaxDist N (parsed and reported; not otherwise used
    by the code in this function).

    For every non-comment line of the region file, the reads overlapping the
    region are fetched with tabix; regions with fewer than minCoverage
    spanning reads are dropped.  Each kept region line is printed and written
    to the output file with the mean entropy appended as an extra
    tab-separated column.

    Exits with status 1 after printing usage when arguments are missing.
    """
    # Eight positional arguments are read below, so argv needs 9 entries.
    if len(sys.argv) < 9:
        print('usage: python %s methylation_reads_all.tsv region.bed chrFieldID leftField rightFieldID minCoverage tabix_location outfileprefix' % sys.argv[0])
        print('\tNote: the script assumes Tombo 1.3 probabilities, and a tabix indexed reads file')
        sys.exit(1)

    reads = sys.argv[1]
    peaks = sys.argv[2]
    chrFieldID = int(sys.argv[3])
    leftFieldID = int(sys.argv[4])
    rightFieldID = int(sys.argv[5])
    minCov = int(sys.argv[6])
    tabix = sys.argv[7]
    outfilename = sys.argv[8]

    EMD = 2000
    if '-expectedMaxDist' in sys.argv:
        EMD = int(sys.argv[sys.argv.index('-expectedMaxDist') + 1])
        print('will use an expected maximum distance of %s' % EMD)

    with open(outfilename, 'w') as outfile:
        regions = _open_text_stream(peaks)
        try:
            LC = 0
            for raw in regions:
                line = raw.strip()
                # Blank lines and '#' comments carry no region.
                if line == '' or line.startswith('#'):
                    continue
                LC += 1
                if LC % 100 == 0:
                    print('%s lines processed' % LC)
                fields = line.split('\t')
                chrom = fields[chrFieldID]
                left = int(fields[leftFieldID])
                right = int(fields[rightFieldID])
                ReadEntropies = _region_read_entropies(
                    tabix, reads, chrom, left, right)
                if len(ReadEntropies) < minCov:
                    continue
                outline = line + '\t' + str(sum(ReadEntropies)/len(ReadEntropies))
                print(outline)
                outfile.write(outline + '\n')
        finally:
            regions.close()
            
# Run the analysis only when executed as a script, so that importing this
# module does not parse sys.argv or start any subprocesses.
if __name__ == '__main__':
    run()

