##################################
#                                #
# Last modified 2018/11/25       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import numpy as np
from scipy.stats import fisher_exact
from scipy.stats import beta
from scipy.stats import binom
from scipy.stats import norm
import random
import os
import math
from sets import Set
from sklearn.metrics import normalized_mutual_info_score as NMIS

def _decompression_command(path):
    """Return the shell command that streams *path* as plain text."""
    if path.endswith('.bz2'):
        return 'bzip2 -cd ' + path
    if path.endswith('.gz') or path.endswith('.bgz'):
        return 'zcat ' + path
    if path.endswith('.zip'):
        return 'unzip -p ' + path
    return 'cat ' + path


def _window_calls(site_probs, start, end, width, alph, bet, pss):
    """Return one '0'/'1' call per window of *width* positions in [start, end).

    Each window starts from a Beta(alph, bet) prior.  Every position in the
    window that is a key of *site_probs* contributes int(pss * p) pseudo-counts
    to the first shape parameter and the remaining pss - int(pss * p) to the
    second.  The call is '1' when the mean of the resulting beta distribution
    is below 0.5 and '0' otherwise.
    """
    calls = []
    for window_start in range(start, end, width):
        A, B = alph, bet
        for j in range(window_start, window_start + width):
            if j in site_probs:
                Z = int(pss * site_probs[j])
                A += Z
                B += pss - Z
        calls.append('1' if beta.mean(A, B) < 0.5 else '0')
    return calls


def run():
    """Write binarized per-read window calls around every peak position.

    Command-line arguments (taken from sys.argv, eight are required):
        1. reads file, queried through tabix
        2. peaks file (plain, .gz/.bgz, .bz2 or .zip)
        3. chrFieldID - 0-based column of the chromosome in the peaks file
        4. posFieldID - 0-based column of the position in the peaks file
        5. radius R   - the interval examined is [pos - R, pos + R)
        6. window W   - the interval is cut into windows of W positions
        7. tabix executable
        8. output file name

    Reads file columns used here (0-based): 1 and 2 are the read start and
    end, 6 is a comma-separated list of positions and 7 the matching
    comma-separated list of values that are parsed as floats.

    Output: a header line ('#peak' followed by the window offsets relative to
    the peak), then one line per read that spans the whole interval, made of
    'chrom:pos' followed by one 0/1 call per window.

    Prints the usage message and exits with status 1 when too few arguments
    are given.
    """
    # sys.argv[8] is read below, so the script name plus eight arguments
    # (nine entries) are needed.
    if len(sys.argv) < 9:
        print('usage: python %s methylation_reads_all.tsv peaks chrFieldID posFieldID radius window tabix_location outfile' % sys.argv[0])
        print('\tNote: the script assumes Tombo 1.3 probabilities, a tabix indexed reads file, and uses a beta distribution prior of (10,10) by default')
        sys.exit(1)

    reads = sys.argv[1]
    peaks = sys.argv[2]
    chrFieldID = int(sys.argv[3])
    posFieldID = int(sys.argv[4])
    R = int(sys.argv[5])
    W = int(sys.argv[6])
    tabix = sys.argv[7]
    outfilename = sys.argv[8]

    # Beta prior shape parameters and the number of pseudo-counts that each
    # per-site value is scaled to.
    alph = 10
    bet = 10
    PSS = 100

    RN = 0
    with open(outfilename, 'w') as outfile:
        header = ['#peak'] + [str(i) for i in range(-R, R, W)]
        outfile.write('\t'.join(header) + '\n')

        peak_stream = os.popen(_decompression_command(peaks), 'r')
        try:
            for peakline in peak_stream:
                peakline = peakline.strip()
                # Blank lines are skipped rather than treated as end of
                # input, so a stray empty line cannot truncate the peak list.
                if not peakline or peakline.startswith('#'):
                    continue
                fields = peakline.split('\t')
                chrom = fields[chrFieldID]
                pos = int(fields[posFieldID])
                left = pos - R
                right = pos + R
                label = chrom + ':' + str(pos)

                cmd = tabix + ' ' + reads + ' ' + chrom + ':' + str(left) + '-' + str(right)
                read_stream = os.popen(cmd, 'r')
                try:
                    for line in read_stream:
                        line = line.strip()
                        if not line:
                            continue
                        fields = line.split('\t')
                        read_left = int(fields[1])
                        read_right = int(fields[2])
                        # The progress counter includes reads that are
                        # filtered out below.
                        RN += 1
                        if RN % 100000 == 0:
                            print('%d reads processed' % RN)
                        # Keep only reads covering the entire interval.
                        if read_left > left or read_right < right:
                            continue
                        cgs = fields[6].split(',')
                        loglike = fields[7].split(',')
                        RD = dict((int(x), float(y)) for x, y in zip(cgs, loglike))
                        calls = _window_calls(RD, left, right, W, alph, bet, PSS)
                        outfile.write('\t'.join([label] + calls) + '\n')
                finally:
                    # One tabix pipe is opened per peak; close it so pipes
                    # do not accumulate.
                    read_stream.close()
        finally:
            peak_stream.close()
            
# Script entry point: run() reads its arguments from sys.argv.
run()

