##################################
#                                #
# Last modified 2025/05/22       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import os
import numpy as np

# IUPAC nucleotide complements: each code in the first string pairs with the
# code at the same index in the second one.  'X' is kept as a self-complementary
# placeholder because the original table accepted it.
_IUPAC_COMPLEMENT = dict(zip('ACGTRYKMSWBDHVNX', 'TGCAYRMKSWVHDBNX'))
_IUPAC_COMPLEMENT.update(
    [(code.lower(), partner.lower()) for code, partner in list(_IUPAC_COMPLEMENT.items())]
)


def getReverseComplement(preliminarysequence):
    """Return the reverse complement of a nucleotide sequence.

    Case is preserved.  Ambiguity codes are complemented according to the
    IUPAC table (R<->Y, K<->M, B<->V, D<->H; S, W, N and X map to themselves).

    Args:
        preliminarysequence: string of nucleotide codes.

    Returns:
        The reverse-complemented string ('' for empty input).

    Raises:
        KeyError: if the sequence contains a character outside the table.
    """
    # Build the result with a single join; repeated string concatenation is
    # quadratic in the sequence length.
    return ''.join([_IUPAC_COMPLEMENT[base] for base in reversed(preliminarysequence)])

def _read_fasta(path):
    """Load a FASTA file into a dict mapping header text to sequence string."""
    genome = {}
    name = None
    chunks = []
    with open(path) as handle:
        for line in handle:
            if line.startswith('>'):
                if name is not None:
                    genome[name] = ''.join(chunks)
                name = line.strip().split('>')[1]
                chunks = []
            elif name is not None:
                chunks.append(line.strip())
    if name is not None:
        genome[name] = ''.join(chunks)
    return genome


def _iter_report_lines(path):
    """Yield the lines of the report, decompressing .bz2 / gz / .zip input."""
    import subprocess

    if path.endswith('.bz2'):
        cmd = ['bzip2', '-cd', path]
    elif path.endswith('gz'):
        cmd = ['gunzip', '-c', path]
    elif path.endswith('.zip'):
        cmd = ['unzip', '-p', path]
    else:
        with open(path) as handle:
            for line in handle:
                yield line
        return

    # An argument list (no shell) keeps paths containing spaces or shell
    # metacharacters from being misinterpreted.
    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, universal_newlines=True)
    try:
        for line in proc.stdout:
            yield line
    finally:
        proc.stdout.close()
        status = proc.wait()
    if status != 0:
        sys.stderr.write('warning: %s exited with status %d\n' % (cmd[0], status))


def _dinucleotide(seq, start):
    """Return the upper-cased two bases at seq[start:start+2] ('' if start < 0)."""
    if start < 0:
        # A negative start would silently wrap around to the end of the sequence.
        return ''
    return seq[start:start + 2].upper()


def _context_matches(seq, pos, strand, Ccontext):
    """Tell whether the cytosine at 0-based `pos` sits in the requested context.

    Any context other than 'CpG', 'GpC' or 'both' (e.g. 'C'), and any strand
    other than '+' or '-', accepts the record unconditionally.
    """
    if Ccontext not in ('CpG', 'GpC', 'both'):
        return True
    if strand == '+':
        cpg = _dinucleotide(seq, pos)          # C at pos, G at pos + 1
        gpc = _dinucleotide(seq, pos - 1)      # G at pos - 1, C at pos
    elif strand == '-':
        # The cytosine is on the reverse strand, so its 3' neighbour lies at
        # pos - 1 and its 5' neighbour at pos + 1 in forward coordinates.
        cpg = getReverseComplement(_dinucleotide(seq, pos - 1))
        gpc = getReverseComplement(_dinucleotide(seq, pos))
    else:
        return True
    if Ccontext == 'CpG':
        return cpg == 'CG'
    if Ccontext == 'GpC':
        return gpc == 'GC'
    return cpg == 'CG' or gpc == 'GC'


def _emit_window(chrom, window, radius, chromLength, meth, unmeth):
    """Print one 'chr, start, end, methylated fraction' line for a window."""
    total = meth + unmeth
    if total == 0:
        # Only reachable with -minCov <= 0; the fraction is undefined.
        return
    start = max(0, window)
    end = min(window + radius, chromLength)
    print('\t'.join([chrom, str(start), str(end), str(meth / float(total))]))


def run():
    """Bin cytosine methylation calls into fixed-size windows and print them.

    Command line: CX_report genome.fa CpG|GpC|both|C radius
                  [-MethylDackel] [-minCov N]

    Default input columns: chr, 1-based position, strand, methylated count,
    unmethylated count.  With -MethylDackel the columns are chr, left, right,
    MethPerc, Meth, Unmeth; that layout has no strand column, so no sequence
    context filtering is applied to it.

    For every window of `radius` bases that holds at least one record passing
    the context and coverage filters, a tab-separated line
    'chr  start  end  sum(M) / (sum(M) + sum(U))' is written to stdout.
    Records of one window are expected to be contiguous in the input
    (presumably the report is sorted by chromosome and position).

    Raises:
        SystemExit: on missing arguments or a non-positive radius.
        KeyError: if a report chromosome is absent from the FASTA file.
    """
    # Four positional arguments are read below, so argv needs five entries.
    if len(sys.argv) < 5:
        print('usage: python %s CX_report.txt.gz genome.fa CpG|GpC|both|C radius [-MethylDackel] [-minCov N]' % sys.argv[0])
        print('\tInput track can be compressed')
        print('\tthe script will print to stdout by default')
        sys.exit(1)

    report = sys.argv[1]
    fasta = sys.argv[2]
    Ccontext = sys.argv[3]
    radius = int(sys.argv[4])
    if radius <= 0:
        sys.exit('radius must be a positive integer')

    doMD = '-MethylDackel' in sys.argv

    minCov = 1
    if '-minCov' in sys.argv:
        minCov = int(sys.argv[sys.argv.index('-minCov') + 1])

    GenomeDict = _read_fasta(fasta)

    currentChr = None
    currentWindow = 0
    methSum = 0
    unmethSum = 0
    for line in _iter_report_lines(report):
        line = line.strip()
        if line == '' or line.startswith('track') or line.startswith('#'):
            continue
        fields = line.split('\t')
        chrom = fields[0]
        # NOTE(review): MethylDackel bedGraph starts look 0-based, in which
        # case this '- 1' shifts them by one base -- left as is, confirm.
        pos = int(fields[1]) - 1
        if doMD:
            M = int(fields[4])
            unM = int(fields[5])
        else:
            # The strand is the third column; the second is the position.
            if not _context_matches(GenomeDict[chrom], pos, fields[2], Ccontext):
                continue
            M = int(fields[3])
            unM = int(fields[4])
        if M + unM < minCov:
            continue
        # Floor division keeps the window start an integer on Python 2 and 3.
        window = (pos // radius) * radius
        if chrom != currentChr or window != currentWindow:
            if currentChr is not None:
                _emit_window(currentChr, currentWindow, radius,
                             len(GenomeDict[currentChr]), methSum, unmethSum)
            currentChr = chrom
            currentWindow = window
            methSum = 0
            unmethSum = 0
        methSum += M
        unmethSum += unM

    # Flush the last open window with the same formula as all the others.
    if currentChr is not None:
        _emit_window(currentChr, currentWindow, radius,
                     len(GenomeDict[currentChr]), methSum, unmethSum)

# Run the pipeline only when executed as a script, so that importing this
# module (e.g. to reuse getReverseComplement) has no side effects.
if __name__ == '__main__':
    run()
