##################################
#                                #
# Last modified 2023/05/13       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import numpy as np
import random
import os
from sets import Set
import re

def getReverseComplement(preliminarysequence):
    """Return the reverse complement of a DNA string.

    A, C, G, T and N are complemented in either case (case is preserved);
    the upper-case masking character 'X' maps to itself. Any other
    character raises KeyError.
    """

    DNA = {'A':'T','T':'A','G':'C','C':'G','N':'N','a':'t','t':'a','g':'c','c':'g','n':'n','X':'X'}
    # Build the result with a single join instead of repeated string
    # concatenation, which is quadratic in the worst case.
    return ''.join([DNA[base] for base in reversed(preliminarysequence)])

def _loadGenome(fasta):
    """Read a FASTA file into a dict of upper-cased sequences.

    Keys are the text following '>' on each header line (whitespace
    stripped); values are the concatenated, upper-cased sequence lines.
    Lines appearing before the first header are ignored.
    """
    genome = {}
    name = None
    pieces = []
    with open(fasta) as handle:
        for line in handle:
            if line[0] == '>':
                if name is not None:
                    genome[name] = ''.join(pieces).upper()
                name = line.strip().split('>')[1]
                pieces = []
            else:
                pieces.append(line.strip())
    if name is not None:
        genome[name] = ''.join(pieces).upper()
    return genome


def _maskContexts(genome, contexts, radius):
    """Replace, in place, every base within radius bp of a context match by 'X'.

    Each entry of contexts is used as a regular expression against each
    sequence in genome; per-sequence masking statistics are printed.
    """
    for chrom in sorted(genome):
        print('masking %s' % chrom)
        seq = genome[chrom]
        masked = set()
        for context in contexts:
            # Progress fragment without a newline; the statistics printed
            # below complete the line.
            sys.stdout.write('%s %s ' % (context, chrom))
            for match in re.finditer(context, seq):
                # Clamp to the sequence so the reported count only includes
                # positions that actually exist.
                lo = max(match.start() - radius, 0)
                hi = min(match.end() + radius, len(seq))
                masked.update(range(lo, hi))
        if seq:
            fraction = len(masked) / (len(seq) + 0.0)
        else:
            fraction = 0.0
        print('%s %s %s' % (len(masked), len(seq), fraction))
        genome[chrom] = ''.join(['X' if i in masked else base
                                 for i, base in enumerate(seq)])


def _decompressCommand(reads):
    """Return the argument list of the command that streams reads to stdout."""
    if reads.endswith('.bz2'):
        return ['bzip2', '-cd', reads]
    if reads.endswith('.gz'):
        return ['gunzip', '-c', reads]
    if reads.endswith('.zip'):
        return ['unzip', '-p', reads]
    return ['cat', reads]


def _collectReads(reads, genome, SEQ, doFS, doExcludeContext):
    """Parse the per-read calls file and group the retained calls by read.

    Each data line is tab-separated; field 0 is used as the read ID, field 1
    as the chromosome, field 2 as the strand, field 3 as the position and
    field 4 as the log-likelihood value. A call is kept only when the genome
    sequence starting at its position (reverse-complemented and ending at its
    position for '-' strand calls) equals SEQ and, if doExcludeContext is
    set, the position is not masked with 'X'. If doFS is set the sign of the
    value is flipped.

    Returns {chromosome: {read: {'strand': str, 'ps': [int], 'lls': [str]}}}.
    """
    # Local import: subprocess is not among the module-level imports.
    import subprocess

    seqR = len(SEQ)
    ReadDict = {}
    RN = 0
    # An argument list (no shell) keeps paths with spaces or shell
    # metacharacters intact.
    proc = subprocess.Popen(_decompressCommand(reads),
                            stdout=subprocess.PIPE,
                            universal_newlines=True)
    try:
        for line in proc.stdout:
            line = line.strip()
            if line == '':
                # A blank line is skipped; it must not end parsing early.
                continue
            if line.startswith('read_id\tchrm\t'):
                continue
            fields = line.split('\t')
            RN += 1
            if RN % 1000000 == 0:
                print(str(RN / 1000000.) + 'M lines processed')
            chrom = fields[1]
            strand = fields[2]
            pos = int(fields[3])
            if strand == '+':
                if genome[chrom][pos:pos + seqR] != SEQ:
                    continue
            if strand == '-':
                start = pos - seqR + 1
                if start < 0:
                    # The context would start before the chromosome; a
                    # negative slice index would wrap around to its end.
                    continue
                window = genome[chrom][start:pos + 1]
                try:
                    context = getReverseComplement(window)
                except KeyError:
                    # The window holds a base outside the complement table
                    # (e.g. an ambiguity code), so it cannot match SEQ.
                    continue
                if context != SEQ:
                    continue
            if doExcludeContext:
                if genome[chrom][pos] == 'X':
                    continue
            read = fields[0]
            loglike = fields[4]
            if doFS:
                loglike = str((-1) * float(loglike))
            if chrom not in ReadDict:
                ReadDict[chrom] = {}
            if read not in ReadDict[chrom]:
                ReadDict[chrom][read] = {'ps': [], 'lls': []}
            entry = ReadDict[chrom][read]
            entry['strand'] = strand
            entry['ps'].append(pos)
            entry['lls'].append(loglike)
    finally:
        proc.stdout.close()
        status = proc.wait()
    if status != 0:
        sys.stderr.write('warning: reading %s exited with status %s\n' % (reads, status))
    return ReadDict


def _writeReads(ReadDict, outfilename):
    """Write one tab-separated line per read to outfilename.

    Columns: chromosome, minimum position, maximum position, strand, read ID,
    the literal 'nan', comma-joined positions, comma-joined values. For '-'
    strand reads both lists are written in reverse of their input order.
    """
    K = 0
    with open(outfilename, 'w') as outfile:
        for chrom in sorted(ReadDict):
            print(chrom)
            for readID in ReadDict[chrom]:
                K += 1
                if K % 100000 == 0:
                    print(K)
                record = ReadDict[chrom][readID]
                positions = record['ps']
                loglikes = record['lls']
                strand = record['strand']
                left = min(positions)
                right = max(positions)
                if strand == '-':
                    positions = positions[::-1]
                    loglikes = loglikes[::-1]
                columns = [chrom, str(left), str(right), strand, readID, 'nan',
                           ','.join([str(p) for p in positions]),
                           ','.join([str(L) for L in loglikes])]
                outfile.write('\t'.join(columns) + '\n')


def run():
    """Command-line entry point.

    Reads sys.argv: per-read calls file, genome FASTA, sequence context and
    output file, plus the optional -flipSign and
    -excludeContext string(,string2,...) radius flags. Prints the usage
    message and exits with status 1 when a positional argument is missing.
    """

    # Four positional arguments are read below (sys.argv[1..4]), so at least
    # five entries are required.
    if len(sys.argv) < 5:
        print('usage: python %s per_read_modified_base_calls.txt genome.fa sequenceContext(s) outfile [-flipSign] [-excludeContext string(,string2,string3,...,stringN) radius]' % sys.argv[0])
        print('note: the sequence context would usually be A or CG, or both; for generalized sequence contexts, the modified position is understood to be the first one in the sequence')
        sys.exit(1)

    reads = sys.argv[1]
    fasta = sys.argv[2]
    SEQ = sys.argv[3].upper()
    outfilename = sys.argv[4]

    doFS = False
    if '-flipSign' in sys.argv:
        doFS = True
        print('will flip signs')

    GenomeDict = _loadGenome(fasta)

    print('finished inputting genomic sequence')

    doExcludeContext = False
    if '-excludeContext' in sys.argv:
        doExcludeContext = True

        flagIndex = sys.argv.index('-excludeContext')
        contexts = [cont.upper() for cont in sys.argv[flagIndex + 1].split(',')]
        radius = int(sys.argv[flagIndex + 2])
        print('will filter out the following sequence contexts: %s with a radius of %s bp around each match' % (contexts, radius))

        _maskContexts(GenomeDict, contexts, radius)

        print('finished masking genomic sequence')

    ReadDict = _collectReads(reads, GenomeDict, SEQ, doFS, doExcludeContext)

    print('finished inputting reads')

    _writeReads(ReadDict, outfilename)

            
# Script entry point. NOTE(review): there is no `if __name__ == '__main__':`
# guard, so importing this module also executes run() against sys.argv.
run()

