##################################
#                                #
# Last modified 2018/03/07       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import pysam
import string
from sets import Set
import os
import subprocess

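# SAM FLAG bit meanings (copied from the SAM format specification; the trailing
# "1" / "1,2" on some lines are footnote markers from that specification)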
# 0x0001 1 the read is paired in sequencing, no matter whether it is mapped in a pair
# 0x0002 2 the read is mapped in a proper pair (depends on the protocol, normally inferred during alignment) 1
# 0x0004 4 the query sequence itself is unmapped
# 0x0008 8 the mate is unmapped 1
# 0x0010 16 strand of the query (0 for forward; 1 for reverse strand)
# 0x0020 32 strand of the mate 1
# 0x0040 64 the read is the first read in a pair 1,2
# 0x0080 128 the read is the second read in a pair 1,2
# 0x0100 256 the alignment is not primary (a read having split hits may have multiple primary alignment records)
# 0x0200 512 the read fails platform/vendor quality checks
# 0x0400 1024 the read is either a PCR duplicate or an optical duplicate

def getstrand(FLAGfields):
    """Return the strand symbol for a decomposed SAM FLAG.

    FLAGfields is any container of FLAG bit values (as produced by FLAG()).
    The reverse-strand bit is 16 (0x10): its presence yields '-', otherwise '+'.
    """
    return '-' if 16 in FLAGfields else '+'

def FLAG(FLAG):
    """Decompose a SAM FLAG integer into its individual bit values.

    Returns the set power-of-two components, largest first, e.g.
    83 -> [64, 16, 2, 1]. Zero or negative input returns [].

    Works for every bit, including 0x800 (2048) and above. A fixed table
    topping out at 1024 mis-decomposes such flags.
    """
    if FLAG <= 0:
        return []
    # Walk from the highest set bit down so the order is largest-first.
    return [1 << bit
            for bit in range(FLAG.bit_length() - 1, -1, -1)
            if (FLAG >> bit) & 1]

# Orientation label keyed by (strand of leftmost mate, strand of rightmost mate).
_ORIENTATION_TYPE = {
    ('+', '+'): 'L',
    ('+', '-'): 'I',
    ('-', '+'): 'O',
    ('-', '-'): 'R',
}


def _new_bin():
    """Return an empty per-orientation container of unique pair keys."""
    return {'L': set(), 'I': set(), 'O': set(), 'R': set()}


def _read_id(name):
    """Strip whitespace-separated comments and mate/length suffixes from a read name.

    Removes everything after ' ', '/2', '_2:', '/1', '_1:' and '_length=' so
    both mates of a pair compare equal.
    """
    return name.split(' ')[0].split('/2')[0].split('_2:')[0].split('/1')[0].split('_1:')[0].split('_length=')[0]


def _read_end(fields):
    """Return (strand, position) for one tab-split SAM line.

    The position is POS (column 4). For '-' strand reads it is shifted by the
    length of SEQ (column 10).
    """
    strand = getstrand(FLAG(int(fields[1])))
    pos = int(fields[3])
    if strand == '-':
        pos += len(fields[9])
    return strand, pos


def _orient(pos1, strand1, pos2, strand2):
    """Return (distance, orientation) for a pair of mate positions.

    Mates are ordered by position (on a tie, mate 2 is treated as leftmost).
    The distance is rightmost - leftmost (always >= 0). The orientation is
    looked up in _ORIENTATION_TYPE.
    """
    if pos2 > pos1:
        left, right = (pos1, strand1), (pos2, strand2)
    else:
        left, right = (pos2, strand2), (pos1, strand1)
    return right[0] - left[0], _ORIENTATION_TYPE[(left[1], right[1])]


def _tally_pairs(lines1, lines2, BAM1, BAM2, BS, FN, DistanceDict):
    """Read two SAM text streams in lockstep and record unique aligned pairs.

    lines1/lines2 must yield the read1/read2 records in the same read order.
    FN caps the number of lines read from each stream (None = no cap).
    Same-chromosome pairs go into DistanceDict[bin], where bin is the distance
    rounded down to a multiple of BS. Other pairs go into
    DistanceDict['InterChromosomal']. Exits the program if read names
    disagree. Returns the number of pairs with both mates aligned.
    """
    aligned = 0
    i = 0
    for line1 in iter(lines1.readline, ''):
        line2 = lines2.readline()
        i += 1
        if i % 1000000 == 0:
            print(str(i // 1000000) + 'M alignments processed in ' + BAM1 + ' ' + BAM2)
        if FN is not None and i > FN:
            break
        fields1 = line1.strip().split('\t')
        fields2 = line2.strip().split('\t')
        if _read_id(fields1[0]) != _read_id(fields2[0]):
            print('files not properly sorted, exiting')
            print(fields1)
            print(fields2)
            print('line number: ' + str(i))
            sys.exit(1)
        chr1 = fields1[2]
        chr2 = fields2[2]
        # '*' in RNAME means the mate is unaligned; such pairs are skipped.
        if chr1 == '*' or chr2 == '*':
            continue
        strand1, pos1 = _read_end(fields1)
        strand2, pos2 = _read_end(fields2)
        distance, orientation = _orient(pos1, strand1, pos2, strand2)
        aligned += 1
        if chr1 != chr2:
            distBin = 'InterChromosomal'
        else:
            distBin = distance - distance % BS
            if distBin not in DistanceDict:
                DistanceDict[distBin] = _new_bin()
        # Sets deduplicate identical (chr1, pos1, chr2, pos2) pairs as they arrive.
        DistanceDict[distBin][orientation].add((chr1, pos1, chr2, pos2))
    return aligned


def run():
    """Command-line entry point: tally deduplicated read-pair distances by orientation.

    Usage: python SCRIPT samtools read1_BAMs read2_BAMs bin_size outfile [-first N_reads]

    read1_BAMs/read2_BAMs are comma-separated lists paired by position. Each
    pair of files is streamed through `samtools view`. Unique aligned pairs
    are binned by distance (or marked InterChromosomal) and by orientation
    (L/I/O/R). Counts are normalised per million deduplicated pairs. The
    per-bin total and orientation fractions are written to `outfile`.
    """
    if len(sys.argv) < 6:
        print('usage: python %s samtools read1_BAM(,BAM2,BAM3,...,) read2_BAM(,BAM2,BAM3,...,) bin_size outfile [-first N_reads]' % sys.argv[0])
        print('\tthe input BAM files are assumed to be the product of the following Bowtie command: -p 1 -v 2 -k 2 -m 1')
        sys.exit(1)

    samtools = sys.argv[1]
    BAM1s = sys.argv[2].split(',')
    BAM2s = sys.argv[3].split(',')
    BS = int(sys.argv[4])
    outfilename = sys.argv[5]

    if len(BAM1s) != len(BAM2s):
        print('read1 and read2 BAM lists have different lengths, exiting')
        sys.exit(1)
    if BS <= 0:
        print('bin_size must be a positive integer, exiting')
        sys.exit(1)

    FN = None
    if '-first' in sys.argv:
        FN = int(sys.argv[sys.argv.index('-first') + 1])

    # bin (int distance or 'InterChromosomal') -> orientation -> set of pair keys.
    DistanceDict = {'InterChromosomal': _new_bin()}
    TotalReads = 0

    for BAM1, BAM2 in zip(BAM1s, BAM2s):
        # Argument lists (no shell) so paths with spaces/metacharacters are safe.
        p1 = subprocess.Popen([samtools, 'view', BAM1], stdout=subprocess.PIPE, universal_newlines=True)
        p2 = subprocess.Popen([samtools, 'view', BAM2], stdout=subprocess.PIPE, universal_newlines=True)
        try:
            TotalReads += _tally_pairs(p1.stdout, p2.stdout, BAM1, BAM2, BS, FN, DistanceDict)
        finally:
            for proc in (p1, p2):
                proc.stdout.close()
                proc.wait()

    print('dedupping:')

    counts = {}
    TotalDedupedReads = 0.0
    for distBin, byType in DistanceDict.items():
        counts[distBin] = {t: len(keys) for t, keys in byType.items()}
        TotalDedupedReads += sum(counts[distBin].values())

    NormFactor = TotalDedupedReads / 1000000
    normalized = {}
    for distBin, byType in counts.items():
        # With no pairs at all NormFactor is 0; report zeros instead of crashing.
        normalized[distBin] = {t: (c / NormFactor if NormFactor else 0.0)
                               for t, c in byType.items()}

    print('total aligned read pairs: ' + str(TotalReads))
    print('dedupped reads: ' + str(TotalDedupedReads))

    # Numeric distance bins ascending, then the InterChromosomal row last.
    distBins = sorted(b for b in normalized if b != 'InterChromosomal')
    distBins.append('InterChromosomal')

    with open(outfilename, 'w') as outfile:
        outfile.write('#total aligned read pairs:\t' + str(TotalReads) + '\n')
        outfile.write('#dedupped reads:\t' + str(TotalDedupedReads) + '\n')
        outfile.write('#distance\ttotal\tL\tI\tO\tR\n')
        for distBin in distBins:
            L, I, O, R = [normalized[distBin][t] for t in 'LIOR']
            total = L + I + O + R
            if total:
                fractions = [L / total, I / total, O / total, R / total]
            else:
                # Only the pre-created InterChromosomal row can be empty; its
                # orientation fractions are undefined, so write NaN.
                fractions = [float('nan')] * 4
            outfile.write('\t'.join([str(distBin), str(total)] + [str(f) for f in fractions]) + '\n')
        
# Script entry point.
# NOTE(review): this call also executes when the module is imported; an
# `if __name__ == '__main__':` guard would make the helpers importable.
run()