##################################
#                                #
# Last modified 2018/10/08       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import pysam
import string
from sets import Set
import os
import subprocess

# FLAG field meaning (hex value, decimal value, description)
# 0x0001 1 the read is paired in sequencing, no matter whether it is mapped in a pair
# 0x0002 2 the read is mapped in a proper pair (depends on the protocol, normally inferred during alignment)
# 0x0004 4 the query sequence itself is unmapped
# 0x0008 8 the mate is unmapped
# 0x0010 16 strand of the query (0 for forward; 1 for reverse strand)
# 0x0020 32 strand of the mate
# 0x0040 64 the read is the first read in a pair
# 0x0080 128 the read is the second read in a pair
# 0x0100 256 the alignment is not primary (a read having split hits may have multiple primary alignment records)
# 0x0200 512 the read fails platform/vendor quality checks
# 0x0400 1024 the read is either a PCR duplicate or an optical duplicate

def getstrand(FLAGfields):
    """Return the strand symbol for a decoded FLAG value.

    FLAGfields is a container of individual flag values; the result is
    '-' when 16 (the reverse-strand flag) is among them and '+' otherwise.
    """
    return '-' if 16 in FLAGfields else '+'

def FLAG(FLAG):
    """Decompose an integer SAM FLAG into the list of bit values that are set.

    The values are returned in descending order, e.g. 83 -> [64, 16, 2, 1].
    Zero and negative inputs yield an empty list.

    Every set bit is reported, including values above 1024 (such as 2048),
    so such flags no longer produce spurious entries or an IndexError.
    """

    set_bits = []

    # Walk the powers of two up to the flag value and keep those present.
    bit = 1
    while bit <= FLAG:
        if FLAG & bit:
            set_bits.append(bit)
        bit <<= 1

    # Callers receive the largest flag value first.
    set_bits.reverse()

    return set_bits

def _read_id(qname):
    """Return a read name with mate-specific decorations removed.

    The name is cut at the first occurrence of each separator in turn, so
    that the two mates of a pair yield the same identifier.
    """
    for separator in (' ', '/2', '_2:', '/1', '_1:', '_length='):
        qname = qname.split(separator)[0]
    return qname


def run():
    """Bin paired alignments into an interaction matrix and write it out.

    Command-line arguments (read from sys.argv):
        samtools   command used to run 'samtools view'
        read1_BAM  comma-separated list of BAM files with the first mates
        read2_BAM  comma-separated list of BAM files with the second mates
        bin_size   integer bin width
        outfile    path of the tab-delimited output file
    Optional switches: -noInterchromosomalInteractions, -noFR (only honoured
    together with the former), -dedup, -first N_reads.

    The two BAM files of each pair are read in lockstep; the N-th record of
    each must carry the same read name, otherwise the program exits with
    status 1. Each pair with both mates aligned is assigned to a
    (chr1, bin1, chr2, bin2) cell. Cell counts are divided by the total
    number of counted pairs in millions before being written.
    """

    # Five positional arguments are required (sys.argv[1] .. sys.argv[5]).
    if len(sys.argv) < 6:
        print('usage: python %s samtools read1_BAM(,BAM2,BAM3,...,) read2_BAM(,BAM2,BAM3,...,) bin_size outfile [-noFR] [-noInterchromosomalInteractions] [-dedup] [-first N_reads]' % sys.argv[0])
        print('\tthe [-noFR] option will exclude concordant fragments (which might be the result of unligated DNA; it only applies when the [-noInterchromosomalInteractions] option has been specified')
        print('\tthe input BAM files are assumed to be the product of the following Bowtie command: -p 1 -v 2 -k 2 -m 1')
        sys.exit(1)

    samtools = sys.argv[1]
    BAM1s = sys.argv[2].split(',')
    BAM2s = sys.argv[3].split(',')
    BS = int(sys.argv[4])
    outfilename = sys.argv[5]

    # The files are consumed pairwise, so both lists must line up.
    if len(BAM1s) != len(BAM2s):
        print('read1 and read2 BAM lists must have the same length, exiting')
        sys.exit(1)

    doFirstN = False
    FN = 0
    if '-first' in sys.argv:
        doFirstN = True
        FN = int(sys.argv[sys.argv.index('-first') + 1])

    doNoFR = False

    doNoICIA = False
    if '-noInterchromosomalInteractions' in sys.argv:
        print('will consider only intrachromosomal interactions')
        doNoICIA = True
        if '-noFR' in sys.argv:
            doNoFR = True
            print('will exclude all FR reads')

    doDeDup = False
    if '-dedup' in sys.argv:
        doDeDup = True
        print('will dedup fragments')

    # InteractionMatrix[chr1][chr2][bin1][bin2] holds a count, or, while
    # dedupping, the list of fragment signatures seen in that cell.
    InteractionMatrix = {}

    TotalReads = 0.0
    TA = 0.0

    for BAM1, BAM2 in zip(BAM1s, BAM2s):
        p1 = os.popen(samtools + ' view ' + BAM1, "r")
        p2 = os.popen(samtools + ' view ' + BAM2, "r")
        try:
            # The record counter (and hence -first) restarts for every file pair.
            i = 0
            while True:
                line1 = p1.readline()
                line2 = p2.readline()
                if line1 == '':
                    break
                i += 1
                if i % 1000000 == 0:
                    print('%dM alignments processed in %s %s' % (i // 1000000, BAM1, BAM2))
                if doFirstN and i > FN:
                    break
                fields1 = line1.strip().split('\t')
                fields2 = line2.strip().split('\t')
                if _read_id(fields1[0]) != _read_id(fields2[0]):
                    print('files not properly sorted, exiting')
                    print(fields1)
                    print(fields2)
                    print('line number: %d' % i)
                    sys.exit(1)
                chr1 = fields1[2]
                chr2 = fields2[2]
                # '*' in the reference column means the mate is not aligned.
                if chr1 == '*' or chr2 == '*':
                    continue
                if doNoICIA and chr1 != chr2:
                    continue
                strand1 = getstrand(FLAG(int(fields1[1])))
                strand2 = getstrand(FLAG(int(fields2[1])))
                pos1 = int(fields1[3])
                pos2 = int(fields2[3])
                # For reverse-strand reads use the far end of the read:
                # start position plus the length of the sequence column.
                if strand1 == '-':
                    pos1 = pos1 + len(fields1[9])
                if strand2 == '-':
                    pos2 = pos2 + len(fields2[9])
                if doNoFR:
                    # Skip pairs whose mates face each other (forward mate upstream).
                    if strand1 == '+' and strand2 == '-' and pos2 > pos1:
                        continue
                    if strand1 == '-' and strand2 == '+' and pos1 > pos2:
                        continue
                bin1 = pos1 - (pos1 % BS)
                bin2 = pos2 - (pos2 % BS)
                if chr1 == chr2:
                    # Within one chromosome the cell is stored once, smaller bin first.
                    B1 = min(bin1, bin2)
                    B2 = max(bin1, bin2)
                else:
                    # Across chromosomes each bin must stay with its own
                    # chromosome; reordering would report a chr2 coordinate on chr1.
                    B1 = bin1
                    B2 = bin2
                row = InteractionMatrix.setdefault(chr1, {}).setdefault(chr2, {}).setdefault(B1, {})
                if doDeDup:
                    # Keep one signature per pair; duplicates collapse later.
                    row.setdefault(B2, []).append((pos1 - bin1, pos2 - bin2, strand1, strand2))
                else:
                    row[B2] = row.get(B2, 0) + 1
                    TotalReads += 1
                TA += 1
        finally:
            # Release both pipes even on early exit (-first or a sorting error).
            p1.close()
            p2.close()

    print('dedupping:')

    if doDeDup:
        for chr1 in list(InteractionMatrix):
            print(chr1)
            for chr2 in list(InteractionMatrix[chr1]):
                rows = InteractionMatrix[chr1][chr2]
                for B1 in list(rows):
                    for B2 in list(rows[B1]):
                        # Number of distinct fragment signatures in this cell.
                        R = len(set(rows[B1][B2]))
                        rows[B1][B2] = R
                        TotalReads += R

    # Counts are reported per million counted pairs.
    NormFactor = TotalReads/1000000

    outfile = open(outfilename, 'w')
    try:
        print('NormFactor: %s' % NormFactor)
        print('total aligned read pairs: %s' % TA)
        outfile.write('#total aligned read pairs:\t' + str(TA) + '\n')
        print('dedupped reads: %s' % TotalReads)
        outfile.write('#dedupped reads:\t' + str(TotalReads) + '\n')

        for chr1 in sorted(InteractionMatrix):
            for chr2 in sorted(InteractionMatrix[chr1]):
                rows = InteractionMatrix[chr1][chr2]
                for B1 in sorted(rows):
                    for B2 in sorted(rows[B1]):
                        outline = '\t'.join([chr1, str(B1), chr2, str(B2), str(rows[B1][B2]/NormFactor)])
                        outfile.write(outline + '\n')
    finally:
        outfile.close()
        
# Script entry point: the call is unguarded, so run() executes (and reads
# sys.argv) whenever this file is executed or imported.
run()