##################################
#                                #
# Last modified 2018/08/10       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import numpy as np
import random
import os
import math
from sets import Set

def getReverseComplement(preliminarysequence):
    """Return the reverse complement of a DNA sequence.

    The input is read back to front and every base is swapped for its
    complement (A<->T, G<->C, N->N).  Upper- and lower-case letters are
    mapped independently, so the case of each base is preserved.

    Args:
        preliminarysequence: string made of the characters ACGTN / acgtn.

    Returns:
        The reverse-complemented string (empty for empty input).

    Raises:
        KeyError: if the sequence contains any other character.
    """
    DNA = {'A':'T','T':'A','G':'C','C':'G','N':'N','a':'t','t':'a','g':'c','c':'g','n':'n'}
    # Join once instead of growing a string base by base, which is quadratic
    # in the sequence length.
    return ''.join([DNA[base] for base in reversed(preliminarysequence)])

def _print_usage_and_exit():
    """Print the command-line help and terminate with exit status 1."""
    print('usage: python %s methylation_reads_all.tsv radius outfilename [-regions filename chrFieldID leftFieldID rightFieldID]' % sys.argv[0])
    print('\tNote: the script assumes Tombo 1.3 probabilities')
    sys.exit(1)


def _iter_data_lines(path):
    """Yield the stripped, non-empty, non-comment lines of a text file.

    Files ending in .bz2, .gz/.bgz or .zip are streamed through the external
    tools bzip2, zcat and unzip respectively; anything else is opened
    directly.  Lines starting with '#' and blank lines are skipped (a blank
    line does not end the input).

    Raises:
        IOError / OSError: if the file or the decompression tool cannot be
            opened.
    """
    # Local import so this block does not depend on the file's import header.
    import subprocess

    if path.endswith('.bz2'):
        command = ['bzip2', '-cd', path]
    elif path.endswith('.gz') or path.endswith('.bgz'):
        command = ['zcat', path]
    elif path.endswith('.zip'):
        command = ['unzip', '-p', path]
    else:
        command = None

    if command is None:
        process = None
        stream = open(path, 'r')
    else:
        # An argument list (no shell) keeps file names containing spaces or
        # shell metacharacters from being misinterpreted.
        process = subprocess.Popen(command, stdout=subprocess.PIPE,
                                   universal_newlines=True)
        stream = process.stdout

    try:
        for raw_line in stream:
            line = raw_line.strip()
            if not line or line.startswith('#'):
                continue
            yield line
    finally:
        stream.close()
        if process is not None and process.wait() != 0:
            sys.stderr.write('warning: "%s" exited with status %d\n'
                             % (' '.join(command), process.returncode))


def _load_regions(path, chrom_field, left_field, right_field):
    """Read a tab-delimited regions file into {chromosome: set(positions)}.

    Every integer position in the half-open interval [left, right) of each
    region is recorded.  The three *_field arguments are 0-based column
    indices.
    """
    wanted = {}
    for line in _iter_data_lines(path):
        fields = line.split('\t')
        chrom = fields[chrom_field]
        left = int(fields[left_field])
        right = int(fields[right_field])
        wanted.setdefault(chrom, set()).update(range(left, right))
    return wanted


def _tally_read(calls, strand, radius, allowed, dist_counts):
    """Add one read's pairwise call agreement to dist_counts.

    Args:
        calls: {position: bool} binarised calls of a single read.
        strand: '+' or '-'; selects the sign of the reported distance.
        radius: maximum absolute distance between paired positions.
        allowed: set of positions usable as the centre of a pair, or None
            for no restriction.
        dist_counts: {distance: [n_matching, n_not_matching]}, updated in
            place.
    """
    sign = 1 if strand == '+' else -1
    for pos, status in calls.items():
        if allowed is not None and pos not in allowed:
            continue
        # Inclusive on both ends so that distances -radius..+radius are all
        # covered, matching the keys of dist_counts.
        for npos in range(pos - radius, pos + radius + 1):
            neighbour_status = calls.get(npos)
            if neighbour_status is None:
                continue
            distance = sign * (pos - npos)
            if neighbour_status == status:
                dist_counts[distance][0] += 1
            else:
                dist_counts[distance][1] += 1


def run():
    """Command-line entry point.

    Reads per-read methylation calls from a tab-delimited file (column 0:
    chromosome, column 3: strand, column 6: comma-separated positions,
    column 7: comma-separated values, one per position).  A value below 0.5
    is binarised to status 1, anything else to status 0.  For every pair of
    called positions on the same read that lie within `radius` of each
    other (a position is also paired with itself, at distance 0), the
    script counts whether the two statuses match, keyed by the
    strand-oriented distance, and writes the per-distance totals and the
    matching fraction to the output file.

    With `-regions`, only positions falling inside the given regions are
    used as the centre of a pair; their neighbours may lie outside.

    Reads with fewer than 8 columns, fewer than two values, or a strand
    other than '+'/'-' are skipped.
    """
    # Three positional arguments are required (reads, radius, outfilename).
    if len(sys.argv) < 4:
        _print_usage_and_exit()

    reads = sys.argv[1]
    radius = int(sys.argv[2])
    outfilename = sys.argv[3]

    wanted = None
    if '-regions' in sys.argv:
        flag_index = sys.argv.index('-regions')
        # -regions must be followed by a file name and three column indices.
        if len(sys.argv) < flag_index + 5:
            _print_usage_and_exit()
        wanted = _load_regions(sys.argv[flag_index + 1],
                               int(sys.argv[flag_index + 2]),
                               int(sys.argv[flag_index + 3]),
                               int(sys.argv[flag_index + 4]))
        print('finished inputting regions')

    # distance -> [matching pairs, non-matching pairs]
    dist_counts = {}
    for distance in range(-radius, radius + 1):
        dist_counts[distance] = [0, 0]

    n_lines = 0
    for line in _iter_data_lines(reads):
        n_lines += 1
        if n_lines % 10000 == 0:
            print('%d lines processed' % n_lines)
        fields = line.split('\t')
        # Columns up to index 7 are read below.
        if len(fields) < 8:
            continue
        chrom = fields[0]
        allowed = None
        if wanted is not None:
            if chrom not in wanted:
                continue
            allowed = wanted[chrom]
        strand = fields[3]
        if strand not in ('+', '-'):
            # The distance orientation is undefined without a strand.
            continue
        positions = fields[6].split(',')
        values = fields[7].split(',')
        if len(values) < 2:
            continue
        calls = dict((int(p), float(v) < 0.5)
                     for p, v in zip(positions, values))
        _tally_read(calls, strand, radius, allowed, dist_counts)

    with open(outfilename, 'w') as outfile:
        outfile.write('#Pos\ttotal\tmatching_methylation_status\tnot_matching_methylation_status\tfraction_matching\n')
        for distance in sorted(dist_counts):
            matching, not_matching = dist_counts[distance]
            total = matching + not_matching
            if total > 0:
                columns = [str(distance), str(total), str(matching),
                           str(not_matching), str(matching / (total + 0.0))]
                outfile.write('\t'.join(columns) + '\n')
            
# Run the command-line tool only when this file is executed as a script, so
# that importing the module does not parse sys.argv or exit the interpreter.
if __name__ == '__main__':
    run()

