##################################
#                                #
# Last modified 2018/08/10       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import numpy as np
import random
import os
import math
from sets import Set

# Base -> complement for A/C/G/T/N plus the IUPAC ambiguity codes
# (R<->Y, K<->M, B<->V, D<->H; S and W are self-complementary).
# Lower-case entries mirror the upper-case ones so case is preserved.
_COMPLEMENT = {
    'A': 'T', 'T': 'A', 'G': 'C', 'C': 'G', 'N': 'N',
    'R': 'Y', 'Y': 'R', 'K': 'M', 'M': 'K', 'S': 'S', 'W': 'W',
    'B': 'V', 'V': 'B', 'D': 'H', 'H': 'D',
}
_COMPLEMENT.update({base.lower(): comp.lower() for base, comp in _COMPLEMENT.items()})


def getReverseComplement(preliminarysequence):
    """Return the reverse complement of a nucleotide string.

    Case is preserved per base.  Characters outside A/C/G/T/N and the IUPAC
    ambiguity codes (in either case) raise KeyError.

    The result is assembled with a single join instead of repeated string
    concatenation, which was quadratic in the sequence length.
    """
    return ''.join(_COMPLEMENT[base] for base in reversed(preliminarysequence))

def _iter_text_lines(filename):
    """Yield the stripped, non-blank, non-'#' lines of a possibly compressed file.

    The decompressor is chosen from the file extension as before:
    ``bzip2 -cd`` for .bz2, ``zcat`` for .gz/.bgz, ``unzip -p`` for .zip and
    ``cat`` otherwise.  The command is run as an argument list without a
    shell, so file names with spaces or shell metacharacters are passed
    through verbatim rather than interpreted by /bin/sh.

    A non-zero exit status of the external tool (e.g. missing file, corrupt
    archive) is reported on stderr after the stream is exhausted.
    """
    import subprocess

    if filename.endswith('.bz2'):
        cmd = ['bzip2', '-cd', filename]
    elif filename.endswith(('.gz', '.bgz')):
        cmd = ['zcat', filename]
    elif filename.endswith('.zip'):
        cmd = ['unzip', '-p', filename]
    else:
        cmd = ['cat', filename]

    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, universal_newlines=True)
    try:
        for raw in proc.stdout:
            line = raw.strip()
            # Blank lines are skipped, not treated as end-of-file: the old
            # readline() loop silently stopped at the first blank line.
            if not line or line.startswith('#'):
                continue
            yield line
    finally:
        proc.stdout.close()
        proc.wait()
    if proc.returncode != 0:
        sys.stderr.write('warning: %s exited with status %d\n'
                         % (' '.join(cmd), proc.returncode))


def _read_regions(filename, chrom_field, left_field, right_field):
    """Return {chromosome: set of positions} for every [left, right) interval.

    The field arguments are 0-based column indices into the tab-separated
    regions file.
    """
    wanted = {}
    for line in _iter_text_lines(filename):
        fields = line.split('\t')
        positions = wanted.setdefault(fields[chrom_field], set())
        positions.update(range(int(fields[left_field]), int(fields[right_field])))
    return wanted


def _read_genome(fasta):
    """Load a FASTA file into {name: upper-cased sequence}.

    The name is the header line after the leading '>' (cut at any further
    '>'), exactly as before; it is NOT truncated at the first whitespace, so
    it must match the read file's chromosome column verbatim.
    """
    genome = {}
    name = None
    chunks = []
    with open(fasta) as handle:
        for line in handle:
            if line.startswith('>'):
                if name is not None:
                    genome[name] = ''.join(chunks).upper()
                name = line.strip().split('>')[1]
                chunks = []
            elif name is not None:
                chunks.append(line.strip())
    if name is not None:
        genome[name] = ''.join(chunks).upper()
    return genome


def _kmer_at(chrom_seq, position, kmer_size, strand):
    """Return the kmer_size-long k-mer around `position` on `strand`, or None.

    For even sizes the window is the same as in earlier versions
    ('+': [p - K/2, p + K/2), '-': [p - K/2 + 1, p + K/2 + 1) then reverse-
    complemented), so `position` sits at index K/2 on both strands.  For odd
    sizes the window is centred on `position` (earlier versions returned
    K-1 bases).  None is returned for an unknown strand or a window that
    does not fit inside the chromosome (previously a negative start wrapped
    around and produced an empty or garbage k-mer).
    """
    if strand == '+':
        start = position - kmer_size // 2
    elif strand == '-':
        start = position - (kmer_size - 1) // 2
    else:
        return None
    end = start + kmer_size
    if start < 0 or end > len(chrom_seq):
        return None
    kmer = chrom_seq[start:end]
    if strand == '-':
        kmer = getReverseComplement(kmer)
    return kmer


def _count_kmers(reads, genome, kmer_size, wanted):
    """Tally per-position calls by surrounding k-mer.

    Expected read-file columns (0-based, tab-separated): 0 chromosome,
    3 strand, 6 comma-separated positions, 7 comma-separated per-position
    values.  Reads with fewer than two values are skipped, as before.  A
    value < 0.5 is counted in slot 0 ("meth" in the output), otherwise in
    slot 1.  `wanted` is None or {chromosome: set of allowed positions}.

    Returns (kmer_counts, totals): kmer_counts maps k-mer -> [slot0, slot1]
    and totals is the same pair summed over all k-mers.
    """
    kmer_counts = {}
    totals = [0, 0]
    missing_chroms = set()
    n_lines = 0
    for line in _iter_text_lines(reads):
        n_lines += 1
        if n_lines % 100000 == 0:
            print('%d lines processed' % n_lines)
        fields = line.split('\t')
        chrom = fields[0]
        if wanted is not None and chrom not in wanted:
            continue
        # Columns 6 and 7 are both read below, so at least 8 fields are needed.
        if len(fields) < 8:
            continue
        strand = fields[3]
        chrom_seq = genome.get(chrom)
        if chrom_seq is None:
            missing_chroms.add(chrom)
            continue
        values = fields[7].split(',')
        if len(values) < 2:
            continue
        positions = [int(x) for x in fields[6].split(',')]
        scores = [float(y) for y in values]
        allowed = wanted[chrom] if wanted is not None else None
        for pos, score in zip(positions, scores):
            if allowed is not None and pos not in allowed:
                continue
            kmer = _kmer_at(chrom_seq, pos, kmer_size, strand)
            if kmer is None:
                continue
            slot = 0 if score < 0.5 else 1
            kmer_counts.setdefault(kmer, [0, 0])[slot] += 1
            totals[slot] += 1
    if missing_chroms:
        sys.stderr.write('warning: %d chromosome name(s) in %s not found in the genome, e.g. %s\n'
                         % (len(missing_chroms), reads, sorted(missing_chroms)[0]))
    return kmer_counts, totals


def _write_table(outfilename, kmer_counts, totals):
    """Write the per-k-mer summary table, sorted by k-mer.

    GC_perc is written as a fraction in [0, 1].  fraction_meth is
    slot0 / (slot0 + slot1) for the k-mer and obs_vs_exp divides it by the
    genome-wide slot-0 fraction; 'nan' is written when that fraction is 0
    (previously a ZeroDivisionError).
    """
    grand_total = totals[0] + totals[1]
    exp_meth = totals[0] / float(grand_total) if grand_total else float('nan')
    with open(outfilename, 'w') as outfile:
        outfile.write('#kmer\tGC_perc\ttotal\tmeth\tunmet\tfraction_meth\tobs_vs_exp' + '\n')
        for kmer in sorted(kmer_counts):
            meth, unmeth = kmer_counts[kmer]
            total = meth + unmeth
            obs_meth = meth / float(total)
            gc = (kmer.count('C') + kmer.count('G')) / float(len(kmer))
            ratio = obs_meth / exp_meth if exp_meth else float('nan')
            outfile.write('\t'.join([kmer, str(gc), str(total), str(meth),
                                     str(unmeth), str(obs_meth), str(ratio)]) + '\n')


def run():
    """Command-line entry point: summarise per-read call values by genomic k-mer.

    usage: reads.tsv genome.fa kmer_size outfilename
           [-regions filename chrFieldID leftFieldID rightFieldID]

    With -regions, only positions inside the listed [left, right) intervals
    are counted.  Progress messages go to stdout; the table is written to
    `outfilename` (see _write_table for the columns).
    """
    # Four positional arguments are required (argv[1]..argv[4]).
    if len(sys.argv) < 5:
        print('usage: python %s methylation_reads_all.tsv genome.fa kmer_size outfilename [-regions filename chrFieldID leftFieldID rightFieldID]' % sys.argv[0])
        print('\tNote: the script assumes Tombo 1.3 probabilities')
        sys.exit(1)

    reads = sys.argv[1]
    fasta = sys.argv[2]
    K = int(sys.argv[3])
    outfilename = sys.argv[4]
    if K < 1:
        print('kmer_size must be a positive integer')
        sys.exit(1)

    wanted = None
    if '-regions' in sys.argv:
        idx = sys.argv.index('-regions')
        try:
            regions = sys.argv[idx + 1]
            chrom_field = int(sys.argv[idx + 2])
            left_field = int(sys.argv[idx + 3])
            right_field = int(sys.argv[idx + 4])
        except (IndexError, ValueError):
            print('-regions requires: filename chrFieldID leftFieldID rightFieldID')
            sys.exit(1)
        wanted = _read_regions(regions, chrom_field, left_field, right_field)
        print('finished inputting regions')

    genome = _read_genome(fasta)
    print('finished inputting genomic sequence')

    kmer_counts, totals = _count_kmers(reads, genome, K, wanted)
    _write_table(outfilename, kmer_counts, totals)
            
# Run only when executed as a script, so the module can be imported
# (e.g. to reuse getReverseComplement) without processing sys.argv.
if __name__ == '__main__':
    run()

