##################################
#                                #
# Last modified 2025/02/24       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import os

def FLAG(FLAG):
    """Decompose a bitwise flag value into the individual bits that are set.

    Args:
        FLAG: Integer flag value (callers in this file pass the second
            column of a ``samtools view`` record).

    Returns:
        A list of the powers of two that add up to ``FLAG``, largest first,
        e.g. ``FLAG(99) == [64, 32, 2, 1]``.  Zero and negative values give
        an empty list.
    """
    FLAGList = []

    # Scan upward one bit at a time instead of using a fixed table of powers
    # of two: a fixed table capped at 1024 cannot represent values with
    # higher bits set (e.g. 2048) and produced bogus entries for them.
    bit = 1
    while bit <= FLAG:
        if FLAG & bit:
            FLAGList.append(bit)
        bit <<= 1

    # Bits were collected smallest-first; callers get them largest-first.
    FLAGList.reverse()
    return FLAGList

def _read_alignments(samtools, bam):
    """Yield ``(ID, chr, pos, strand)`` for each record of ``samtools view``.

    Args:
        samtools: Command used to invoke samtools (name or path).
        bam: Path of the BAM file to read.

    Yields:
        Tuples of the read ID (first column, truncated at the first space),
        the third column as a string, the fourth column as an int, and a
        strand value that is 1 when bit 16 of the flag column is set and 0
        otherwise.
    """
    cmd = samtools + ' view ' + bam
    print(cmd)

    pipe = os.popen(cmd, 'r')
    try:
        for line in pipe:
            line = line.strip()
            if not line:
                continue
            fields = line.split('\t')
            read_id = fields[0].split(' ')[0]
            # Test the strand bit directly; no need to expand every flag
            # into a list of bits for each of millions of records.
            strand = 1 if int(fields[1]) & 16 else 0
            # Positions are always converted to int so that alignments from
            # both files compare numerically when sorted together.
            yield read_id, fields[2], int(fields[3]), strand
    finally:
        # close() returns the command's exit status (None on success).
        status = pipe.close()
        if status:
            sys.stderr.write('warning: "%s" exited with status %s\n'
                             % (cmd, status))


def run():
    """Pair up alignments from two BAM files and write deduplicated pairs.

    Command line: ``BAM1 BAM2 samtools outfile``.

    Every read ID seen in BAM1 is recorded; alignments from BAM2 are attached
    only to IDs already seen in BAM1.  Reads with exactly one alignment in
    each file are kept, the two alignments are ordered by
    ``(chr, pos, strand)``, identical pairs are collapsed, and each unique
    pair is written as one space-separated line::

        strand1 chr1 pos1 0 strand2 chr2 pos2 1

    Exits with status 1 after printing a usage message when fewer than four
    arguments are supplied.
    """
    # Four user arguments are required (argv[1] .. argv[4]).
    if len(sys.argv) < 5:
        print('usage: python %s BAM1 BAM2 samtools outfile' % sys.argv[0])
        sys.exit(1)

    BAM1, BAM2, samtools, outfilename = sys.argv[1:5]

    # readDict[ID][1] / readDict[ID][2]: alignments of that read in BAM1/BAM2.
    readDict = {}

    for ID, chrom, pos, strand in _read_alignments(samtools, BAM1):
        if ID not in readDict:
            readDict[ID] = {1: [], 2: []}
        readDict[ID][1].append((chrom, pos, strand))

    print('finished inputting ' + BAM1)

    for ID, chrom, pos, strand in _read_alignments(samtools, BAM2):
        # Reads absent from BAM1 can never form a pair; ignore them.
        if ID in readDict:
            readDict[ID][2].append((chrom, pos, strand))

    print('finished inputting ' + BAM2)

    # Every entry has at least one BAM1 alignment, so a read is present in
    # both files exactly when it also collected a BAM2 alignment.
    inBoth = sum(1 for hits in readDict.values() if hits[2])
    print('found %d reads with alignments in both BAM files' % inBoth)

    uniquePairs = set()
    for hits in readDict.values():
        if len(hits[1]) != 1 or len(hits[2]) != 1:
            continue
        first, second = sorted(hits[1] + hits[2])
        # Keys are stored as strings: (chr1, pos1, str1, chr2, pos2, str2).
        uniquePairs.add(tuple(str(value) for value in first + second))

    # NOTE: the keys are strings, so this final ordering is lexicographic
    # (positions included); that matches the order this script has always
    # written.
    outputList = sorted(uniquePairs)

    print('finished compiling deduplicated alignments')

    with open(outfilename, 'w') as outfile:
        for i, (chr1, pos1, str1, chr2, pos2, str2) in enumerate(outputList, 1):
            if i % 1000000 == 0:
                print(str(i // 1000000) + 'M reads processed')
            outline = ' '.join((str1, chr1, pos1, '0', str2, chr2, pos2, '1'))
            outfile.write(outline + '\n')

# Only run the pipeline when executed as a script, so that importing this
# module (e.g. to reuse FLAG) does not parse sys.argv or launch samtools.
if __name__ == '__main__':
    run()
