##################################
#                                #
# Last modified 2019/04/07       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

# SAM FLAG field bits (hex, decimal, meaning)
# 0x0001 1 the read is paired in sequencing, no matter whether it is mapped in a pair
# 0x0002 2 the read is mapped in a proper pair (depends on the protocol, normally inferred during alignment) [paired only]
# 0x0004 4 the query sequence itself is unmapped
# 0x0008 8 the mate is unmapped [paired only]
# 0x0010 16 strand of the query (0 for forward; 1 for reverse strand)
# 0x0020 32 strand of the mate [paired only]
# 0x0040 64 the read is the first read in a pair [paired only]
# 0x0080 128 the read is the second read in a pair [paired only]
# 0x0100 256 the alignment is not primary (a read having split hits may have multiple primary alignment records)
# 0x0200 512 the read fails platform/vendor quality checks
# 0x0400 1024 the read is either a PCR duplicate or an optical duplicate
# 0x0800 2048 supplementary alignment (part of a chimeric/split alignment)
# [paired only]: the bit is meaningful only when 0x0001 is set.

import sys
import string
import math
from sets import Set
import os
import subprocess

def FLAG(FLAG):
    """Decompose a SAM FLAG integer into the individual bit values it contains.

    Returns the set bits as powers of two in descending order, e.g.
    ``FLAG(83) == [64, 16, 2, 1]``. Zero or negative input yields ``[]``.

    Every bit is decoded, including those above 0x400 (such as 0x800,
    "supplementary alignment"). A greedy subtraction over a fixed table that
    stops at 1024 mis-decodes such values and reports spurious bits like 16.
    """
    if FLAG <= 0:
        return []
    # Walk from the highest set bit down so the order matches the historical
    # (largest-first) output of this function.
    return [1 << bit
            for bit in range(FLAG.bit_length() - 1, -1, -1)
            if FLAG >> bit & 1]

# Strand-orientation count columns of the output, in the order they are written.
_STRAND_PAIRS = (('+', '+'), ('+', '-'), ('-', '+'), ('-', '-'))


def _load_chrom_sizes(path):
    """Return {chromosome: length} read from a tab-separated chromSizes file.

    Only the first two columns are used; lines with fewer than two columns
    (e.g. a trailing blank line) are ignored.
    """
    sizes = {}
    with open(path) as handle:
        for line in handle:
            fields = line.strip().split('\t')
            if len(fields) < 2:
                continue
            sizes[fields[0]] = int(fields[1])
    return sizes


def _read_alignments(samtools, bam, chroms):
    """Return {readID: [(chrom, pos, strand), ...]} for every record in *bam*.

    Records come from ``<samtools> view <bam>`` run through the shell, so
    *samtools* may be a path or a wrapper command. Every reference name seen
    (including '*' for unmapped records) is added to the *chroms* set.
    Strand is '-' when FLAG bit 0x10 is set, otherwise '+'.

    Raises RuntimeError if the samtools command exits with a non-zero status.
    """
    alignments = {}
    pipe = os.popen(samtools + ' view ' + bam, 'r')
    try:
        for line in pipe:
            if not line.strip() or line.startswith('@'):
                continue  # tolerate blank lines and SAM header lines
            fields = line.strip().split('\t')
            chrom = fields[2]
            chroms.add(chrom)
            # Test the bit directly: decoding via a lookup table is only
            # safe for FLAG values below 2048.
            strand = '-' if int(fields[1]) & 0x10 else '+'
            alignments.setdefault(fields[0], []).append(
                (chrom, int(fields[3]), strand))
    finally:
        status = pipe.close()
    if status is not None:
        raise RuntimeError('samtools view failed on %s (status %s)'
                           % (bam, status))
    return alignments


def _bin_links(ends1, ends2, chromSizes, regionSize, minDist, links):
    """Tally binned intra-chromosomal links between the two ends of each read.

    Each read ID present in both *ends1* and *ends2* is processed. Every
    combination of one end-1 alignment with one end-2 alignment is ordered by
    (chrom, pos, strand). Both positions are then rounded down to a multiple
    of *regionSize*.

    A combination is dropped when:
    - the ends lie on different chromosomes;
    - the bins are fewer than *minDist* apart;
    - the wrap-around distance (posA + chromosome size - posB) is below
      *minDist*.

    Each surviving combination increments two counters in
    ``links[chrom][posA][posB]``: the one keyed by its strand pair (see
    _STRAND_PAIRS) and the one keyed by 'total'.

    Returns the set of chromosomes skipped because they are absent from
    *chromSizes*.
    """
    skipped = set()
    for readID, alignments1 in ends1.items():
        alignments2 = ends2.get(readID)
        if not alignments2:
            continue  # only reads seen in both BAMs form a link
        for end1 in alignments1:
            for end2 in alignments2:
                (chromA, rawA, strandA), (chromB, rawB, strandB) = \
                    sorted((end1, end2))
                if chromA != chromB:
                    continue
                posA = rawA - rawA % regionSize
                posB = rawB - rawB % regionSize
                # Sorting guarantees rawA <= rawB, hence posA <= posB.
                if posB - posA < minDist:
                    continue
                chromSize = chromSizes.get(chromA)
                if chromSize is None:
                    skipped.add(chromA)
                    continue
                if posA + chromSize - posB < minDist:
                    continue
                byB = links.setdefault(chromA, {}).setdefault(posA, {})
                counts = byB.get(posB)
                if counts is None:
                    counts = byB[posB] = dict.fromkeys(
                        _STRAND_PAIRS + ('total',), 0)
                counts[(strandA, strandB)] += 1
                counts['total'] += 1
    return skipped


def _write_links(outfilename, links, regionSize, maxThickness):
    """Write one header pair plus the scaled link rows per chromosome.

    Thickness is floor(total / (maxScore / maxThickness)). Here maxScore is
    the largest 'total' on the same chromosome. Rows whose thickness rounds
    down to 0 are omitted.
    """
    with open(outfilename, 'w') as outfile:
        for chrom in sorted(links):
            bins = links[chrom]
            maxScore = 0
            for byB in bins.values():
                for counts in byB.values():
                    maxScore = max(maxScore, counts['total'])
            scalingFactor = maxScore / maxThickness
            outfile.write('#maxScore = ' + str(maxScore) + '\n')
            outfile.write('#chr\tpos1\tpos1 + bin\tchr\tpos2\tpos2 + bin\tthickness\t++\t+-\t-+\t--\n')
            for posA in sorted(bins):
                for posB in sorted(bins[posA]):
                    counts = bins[posA][posB]
                    scaled = counts['total'] / scalingFactor
                    thickness = scaled - scaled % 1  # floor, kept as float
                    if thickness == 0:
                        continue
                    fields = [chrom, str(posA), str(posA + regionSize),
                              chrom, str(posB), str(posB + regionSize),
                              'thickness=' + str(thickness)]
                    fields.extend(str(counts[pair]) for pair in _STRAND_PAIRS)
                    outfile.write('\t'.join(fields) + '\n')


def run():
    """Command-line entry point.

    Usage: end1_BAM end2_BAM samtools chromSizes regionSize minDistance
    maxThickness outfile.

    Pairs the alignments of each read ID found in both BAMs. Links are binned
    into *regionSize* windows and filtered by *minDistance* (both linear and
    wrap-around). Per-chromosome link counts are written to *outfile*.
    Output rows have the columns: chr, pos1, pos1 + bin, chr, pos2,
    pos2 + bin, thickness=N, then the ++, +-, -+ and -- orientation counts.
    """
    # Eight positional arguments are required (sys.argv[1]..sys.argv[8]).
    if len(sys.argv) < 9:
        print('usage: python %s end1_BAM end2_BAM samtools chromSizes regionSize minDistance maxThickness outfile' % sys.argv[0])
        sys.exit(1)

    BAM1 = sys.argv[1]
    BAM2 = sys.argv[2]
    samtools = sys.argv[3]
    chrominfo = sys.argv[4]
    regionSize = int(sys.argv[5])
    minDist = int(sys.argv[6])
    maxThickness = float(sys.argv[7])
    outfilename = sys.argv[8]

    # Both values are used as divisors below.
    if regionSize <= 0:
        sys.exit('error: regionSize must be a positive integer')
    if maxThickness <= 0:
        sys.exit('error: maxThickness must be positive')

    chromSizes = _load_chrom_sizes(chrominfo)

    chroms = set()
    ends1 = _read_alignments(samtools, BAM1, chroms)
    ends2 = _read_alignments(samtools, BAM2, chroms)

    # Every chromosome seen in either BAM gets an output section, even when
    # no link on it survives the filters.
    links = dict((chrom, {}) for chrom in chroms)
    skipped = _bin_links(ends1, ends2, chromSizes, regionSize, minDist, links)
    for chrom in sorted(skipped):
        sys.stderr.write('warning: skipped links on %s (not listed in %s)\n'
                         % (chrom, chrominfo))

    _write_links(outfilename, links, regionSize, maxThickness)

# Run only when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    run()