##################################
#                                #
# Last modified 2025/05/24       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import pysam
import numpy as np
from scipy.stats import entropy
from scipy.stats import fisher_exact
from scipy.stats import beta
from scipy.stats import binom
from sklearn.metrics import normalized_mutual_info_score as NMIS
import random
import os
import math
from sets import Set

def getReverseComplement(preliminarysequence):
    """Return the reverse complement of a DNA sequence.

    Upper- and lower-case A/C/G/T/N are supported and case is preserved.
    Any other character raises KeyError.
    """
    complement = {
        'A': 'T', 'T': 'A', 'G': 'C', 'C': 'G', 'N': 'N',
        'a': 't', 't': 'a', 'g': 'c', 'c': 'g', 'n': 'n',
    }
    # Walk the input back to front, complementing each base as we go.
    return ''.join([complement[base] for base in reversed(preliminarysequence)])

def FLAG(FLAG):
    """Decompose a SAM FLAG integer into its individual set bits.

    Args:
        FLAG: the integer FLAG value of an alignment record.

    Returns:
        A list of the powers of two that are set in FLAG, largest first
        (e.g. 99 -> [64, 32, 2, 1]).  Zero and negative values give [].

    Every bit is decoded, including 2048 (0x800) and any higher bit, so
    no fixed table of known flag values is needed.
    """
    FLAGList = []
    if FLAG <= 0:
        return FLAGList

    # Find the highest power of two that does not exceed FLAG.
    bit = 1
    while (bit << 1) <= FLAG:
        bit <<= 1

    # Emit the set bits from most to least significant.
    while bit:
        if FLAG & bit:
            FLAGList.append(bit)
        bit >>= 1

    return FLAGList

def _loadGenome(fasta):
    """Read a FASTA file into a dict mapping header text to upper-cased sequence."""
    genome = {}
    name = None
    parts = []
    with open(fasta) as handle:
        for line in handle:
            if line.startswith('>'):
                if name is not None:
                    genome[name] = ''.join(parts).upper()
                # The key is the text between the leading '>' and any later '>'.
                name = line.strip().split('>')[1]
                parts = []
            elif name is not None:
                # Sequence lines seen before the first header are ignored.
                parts.append(line.strip())
    if name is not None:
        genome[name] = ''.join(parts).upper()
    return genome


def _loadRegions(peaks, chrFieldID, leftFieldID, rightFieldID):
    """Read a tab-delimited region file into a dict of chromosome -> [(left, right), ...]."""
    import subprocess

    # Pick the decompressor from the file extension.  The command is passed
    # as an argument list so the file name is never interpreted by a shell.
    if peaks.endswith('.bz2'):
        cmd = ['bzip2', '-cd', peaks]
    elif peaks.endswith('.gz') or peaks.endswith('.bgz'):
        cmd = ['zcat', peaks]
    elif peaks.endswith('.zip'):
        cmd = ['unzip', '-p', peaks]
    else:
        cmd = ['cat', peaks]

    regions = {}
    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, universal_newlines=True)
    try:
        for line in proc.stdout:
            line = line.strip()
            # Blank lines and '#' comment lines are skipped; neither ends the read.
            if line == '' or line.startswith('#'):
                continue
            fields = line.split('\t')
            region = (int(fields[leftFieldID]), int(fields[rightFieldID]))
            regions.setdefault(fields[chrFieldID], []).append(region)
    finally:
        proc.stdout.close()
        proc.wait()
    return regions


def _referenceAlignedSequence(alignedread):
    """Return the read's bases laid out one character per reference position."""
    query = alignedread.seq
    pieces = []
    qpos = 0
    for (op, length) in alignedread.cigar:
        if op in (0, 7, 8):
            # M, = and X consume both the read and the reference.
            pieces.append(query[qpos:qpos + length])
            qpos += length
        elif op in (1, 4):
            # Insertions and soft clips consume read bases only.
            # note: insertions are not handled properly, as the junction
            # remaining after excising the insertion might be a CG or GC,
            # but there is no good way to deal with it
            qpos += length
        elif op in (2, 3):
            # Deletions and reference skips consume reference only; pad with N
            # so later bases stay aligned to their reference coordinate.
            pieces.append('N' * length)
    return ''.join(pieces)


def _scoreRead(alignedread, chrom, left, right, GenomeDict):
    """Return [(position, 0 or 1), ...] for one read, restricted to [left, right)."""
    readseq = _referenceAlignedSequence(alignedread)
    # pysam reports the leftmost aligned coordinate 0-based, so it indexes the
    # reference string directly and readseq[0] sits at reference position pos.
    readStart = alignedread.pos
    first = max(readStart, left)
    last = min(readStart + len(readseq), right)
    if first >= last:
        return []

    # Forward-strand reads are scored at reference C, reverse-strand reads at
    # reference G; the score is 1 when the read shows that same base.
    if alignedread.is_reverse:
        refBase = 'G'
    else:
        refBase = 'C'

    chromSeq = GenomeDict[chrom]
    posScores = []
    for i in range(first, last):
        if chromSeq[i:i + 1] == refBase:
            if readseq[i - readStart] == refBase:
                posScores.append((i, 1))
            else:
                posScores.append((i, 0))
    return posScores


def run():
    """Compute pairwise NMI between per-read 0/1 calls at reference C/G positions.

    Command-line arguments (all positional, read from sys.argv):
        1. comma-separated list of BAM files
        2. genome FASTA file
        3. tab-delimited region file (optionally .bz2/.gz/.bgz/.zip)
        4. 0-based column index of the chromosome field
        5. 0-based column index of the left coordinate field
        6. 0-based column index of the right coordinate field
        7. minimum number of reads covering a position pair
        8. output file name

    For each chromosome, reads are fetched from every BAM over every region
    and scored per position (see _scoreRead).  For each pair of positions
    i <= j covered by the same read, the two calls are accumulated; pairs
    with at least minCoverage reads get a normalized mutual information
    score, negated when fewer than half of the reads agree at the two
    positions.  Each output line is
    ``chr<TAB>i<TAB>i+1<TAB>d1:score1;d2:score2;...`` with d = j - i.

    Exits with status 1 when too few arguments are given.
    """
    # Eight positional arguments are read below (sys.argv[1] .. sys.argv[8]).
    if len(sys.argv) < 9:
        print('usage: python %s bam1,bam2,...,bamN genome.fa region.bed chrFieldID leftField rightFieldID minCoverage outfilename ' % sys.argv[0])
        sys.exit(1)

    BAMfiles = sys.argv[1].split(',')
    fasta = sys.argv[2]
    peaks = sys.argv[3]
    chrFieldID = int(sys.argv[4])
    leftFieldID = int(sys.argv[5])
    rightFieldID = int(sys.argv[6])
    minCov = int(sys.argv[7])
    outfilename = sys.argv[8]

    GenomeDict = _loadGenome(fasta)
    print('finished inputting genomic sequence')

    PeakDict = _loadRegions(peaks, chrFieldID, leftFieldID, rightFieldID)
    print('finished inputting peaks')

    outfile = open(outfilename, 'w')
    try:
        outfile.write('#chr\tleft\tright\tNMIscores' + '\n')

        for chrom in sorted(PeakDict):
            PeakDict[chrom].sort()
            # Matrix[i][j] = {0: calls at i, 1: calls at j} over reads covering both.
            Matrix = {}
            for BAM in BAMfiles:
                print('%s %s' % (chrom, BAM))
                samfile = pysam.Samfile(BAM, "rb")
                try:
                    rk = 0
                    for (left, right) in PeakDict[chrom]:
                        for alignedread in samfile.fetch(chrom, left, right):
                            rk += 1
                            if rk % 100000 == 0:
                                print(rk)
                            posScores = _scoreRead(alignedread, chrom, left, right, GenomeDict)
                            for (i, b1) in posScores:
                                row = Matrix.setdefault(i, {})
                                for (j, b2) in posScores:
                                    # Only pairs with j >= i are ever reported,
                                    # so the lower triangle is not stored.
                                    if j < i:
                                        continue
                                    if j not in row:
                                        row[j] = {0: [], 1: []}
                                    row[j][0].append(b1)
                                    row[j][1].append(b2)
                finally:
                    samfile.close()

            for i in sorted(Matrix):
                entries = []
                for j in sorted(Matrix[i]):
                    callsI = np.array(Matrix[i][j][0])
                    if len(callsI) < minCov:
                        continue
                    callsJ = np.array(Matrix[i][j][1])
                    NMI = NMIS(callsI, callsJ, average_method='arithmetic')
                    # Sanity check: dump the inputs and abort on an out-of-range score.
                    if NMI > 2:
                        print(callsI)
                        print(callsJ)
                        print(NMI)
                        sys.exit(1)
                    # NMI is unsigned; mark pairs that mostly disagree as negative.
                    matches = np.sum(callsI == callsJ)
                    if matches < 0.5 * len(callsI):
                        NMI = -NMI
                    entries.append(str(j - i) + ':' + str(round(NMI, 3)))
                # Positions with no sufficiently covered pair produce no line.
                if not entries:
                    continue
                outfile.write(chrom + '\t' + str(i) + '\t' + str(i + 1) + '\t' + ';'.join(entries) + '\n')
    finally:
        outfile.close()
            
# Script entry point: the whole pipeline runs as soon as this file is executed.
run()

