##################################
#                                #
# Last modified 05/10/2013       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import math
from sets import Set

# FLAG field meaning (bitwise flags as defined by the SAM format)
# 0x0001 1 the read is paired in sequencing, no matter whether it is mapped in a pair
# 0x0002 2 the read is mapped in a proper pair (depends on the protocol, normally inferred during alignment)
# 0x0004 4 the query sequence itself is unmapped
# 0x0008 8 the mate is unmapped
# 0x0010 16 strand of the query (0 for forward; 1 for reverse strand)
# 0x0020 32 strand of the mate
# 0x0040 64 the read is the first read in a pair
# 0x0080 128 the read is the second read in a pair
# 0x0100 256 the alignment is not primary (a read having split hits may have multiple primary alignment records)
# 0x0200 512 the read fails platform/vendor quality checks
# 0x0400 1024 the read is either a PCR duplicate or an optical duplicate
# 0x0800 2048 the alignment is supplementary

def FLAG(FLAG):
    """Decompose a SAM FLAG value into the individual bits that are set.

    Parameters:
        FLAG -- integer value of the FLAG column of an alignment record.

    Returns:
        A list with the value of every set bit (1, 2, 4, ...), ordered from
        the largest bit to the smallest.  Zero and negative values give an
        empty list.

    The decomposition is purely bitwise, so it is not limited to a fixed
    table of known flags: values of 2048 and above (for example the
    supplementary-alignment bit 0x800) are reported correctly.
    """
    FLAGList = []

    # Walk the bits from least to most significant; the loop body never runs
    # for FLAG <= 0, which keeps the "empty list" result for such input.
    bit = 1
    while bit <= FLAG:
        if FLAG & bit:
            FLAGList.append(bit)
        bit <<= 1

    # Callers receive the bits in descending order.
    FLAGList.reverse()
    return FLAGList

def _merge_read_into_neighbourhoods(neighbourhoods, pos, read_length, merge_radius):
    """Add the interval [pos, pos + read_length) to one chromosome's regions.

    `neighbourhoods` is a dictionary keyed by (start, end) tuples.  The read
    is merged into the first region whose start lies within `merge_radius`
    of the read start, or whose end lies within `merge_radius` of the read
    end; otherwise the read becomes a new region of its own.
    """
    read_end = pos + read_length
    # Iterate over a snapshot: the dictionary is modified when a match is found.
    for (N1, N2) in list(neighbourhoods):
        if abs(read_end - N2) < merge_radius or abs(pos - N1) < merge_radius:
            del neighbourhoods[(N1, N2)]
            neighbourhoods[(min(pos, N1), max(read_end, N2))] = ''
            return
    neighbourhoods[(pos, read_end)] = ''


def run():
    """Cluster discordant read pairs read from standard input.

    Command line: max_fragment_length merge_radius outfile

    Standard input is expected to be `samtools view` output.  Only the
    alignment of the first read of each pair is used (records with the
    0x80 "second in pair" bit are skipped).  A pair is discordant when the
    mate is on another chromosome or further than max_fragment_length away
    on the same chromosome.  The positions of both ends are merged into
    regions using merge_radius, and one line is written to outfile for every
    pair of connected regions:

        chr1 <tab> start1 <tab> end1 <tab> chr2 <tab> start2 <tab> end2 <tab> count

    The script assumes a uniform read length: the length of the last
    sequence seen is used for all reads.

    Exits with status 1 after printing a usage message when fewer than three
    arguments are supplied.
    """
    # Three arguments are required, i.e. sys.argv needs four entries.
    if len(sys.argv) < 4:
        print('usage: python %s max_fragment_length merge_radius outfile' % sys.argv[0])
        print('\tNote: the script assumes standard input from samtools view')
        print('\tNote: the script assumes uniform read length')
        sys.exit(1)

    max_fragment_length = int(sys.argv[1])
    merge_radius = int(sys.argv[2])
    outputfilename = sys.argv[3]

    # read ID -> (chr1, pos1, chr2, pos2); a later record with the same ID
    # replaces an earlier one.
    DiscordantPairs = {}
    readLength = 0

    c = 0
    for line in sys.stdin:
        c += 1
        if c % 1000000 == 0:
            print(str(c // 1000000) + 'M alignments processed')
        # Skip comment lines and SAM header lines (present with `samtools view -h`).
        if line.startswith('#') or line.startswith('@'):
            continue
        fields = line.split('\t')
        ID = fields[0]
        # 0x80: second read of the pair; each pair is taken from its first read.
        if int(fields[1]) & 128:
            continue
        readLength = len(fields[9])
        chr1 = fields[2]
        pos1 = int(fields[3])
        chr2 = fields[6]
        pos2 = int(fields[7])
        if chr2 == '=':
            if abs(pos2 - pos1) <= max_fragment_length:
                continue
            # '=' means "same reference as the read"; store the real name so
            # that mates on different chromosomes are never pooled together.
            chr2 = chr1
        if chr1 == '*' or chr2 == '*':
            continue
        DiscordantPairs[ID] = (chr1, pos1, chr2, pos2)

    print(len(DiscordantPairs))

    # chromosome -> {(start, end): ''} regions built from read positions
    PairNeighbourhoods = {}
    # chromosome -> {position: (mate chromosome, mate position)}
    PointToPointDict = {}

    for (chr1, pos1, chr2, pos2) in DiscordantPairs.values():
        for chrom in (chr1, chr2):
            if chrom not in PairNeighbourhoods:
                PairNeighbourhoods[chrom] = {}
                PointToPointDict[chrom] = {}
        PointToPointDict[chr1][pos1] = (chr2, pos2)
        PointToPointDict[chr2][pos2] = (chr1, pos1)
        _merge_read_into_neighbourhoods(PairNeighbourhoods[chr1], pos1, readLength, merge_radius)
        _merge_read_into_neighbourhoods(PairNeighbourhoods[chr2], pos2, readLength, merge_radius)

    # Second pass: merge regions whose starts or ends are within merge_radius
    # of each other.
    for chrom in list(PairNeighbourhoods):
        remaining = PairNeighbourhoods[chrom]
        pairs = list(remaining)
        regions = []
        for (N1, N2) in pairs:
            if (N1, N2) not in remaining:
                # Already absorbed into an earlier region.
                continue
            region = (N1, N2)
            for (M1, M2) in pairs:
                if (M1, M2) not in remaining:
                    continue
                if abs(M1 - region[0]) < merge_radius or abs(M2 - region[1]) < merge_radius:
                    region = (min(M1, region[0]), max(M2, region[1]))
                    del remaining[(M1, M2)]
            regions.append(region)
        PairNeighbourhoods[chrom] = {}
        for region in regions:
            PairNeighbourhoods[chrom][region] = ''

    # chromosome -> {position: merged region containing that position}
    PointToMergedRegionDict = {}
    for chrom in PairNeighbourhoods:
        region_of_point = PointToMergedRegionDict.setdefault(chrom, {})
        points = PointToPointDict[chrom]
        for (N1, N2) in PairNeighbourhoods[chrom]:
            for i in range(N1, N2):
                if i in points:
                    region_of_point[i] = (N1, N2)

    # (chr1, start1, end1, chr2, start2, end2) -> number of supporting ends.
    # NOTE(review): every pair is visited from both of its ends, so the count
    # is normally twice the number of pairs; kept for output compatibility.
    MergedPairs = {}
    for chr1 in PairNeighbourhoods:
        points = PointToPointDict[chr1]
        for (N1_1, N1_2) in PairNeighbourhoods[chr1]:
            found = False
            for pos1 in range(N1_1, N1_2):
                if pos1 not in points:
                    continue
                found = True
                (chr2, pos2) = points[pos1]
                (N2_1, N2_2) = PointToMergedRegionDict[chr2][pos2]
                key = (chr1, N1_1, N1_2, chr2, N2_1, N2_2)
                mirrored = (chr2, N2_1, N2_2, chr1, N1_1, N1_2)
                # Reuse the entry created from the other end of the link, if any.
                if key not in MergedPairs and mirrored in MergedPairs:
                    key = mirrored
                MergedPairs[key] = MergedPairs.get(key, 0) + 1
            if not found:
                print('not found %s' % ((N1_1, N1_2),))
                print(points)
                print('=========')

    # The context manager closes the file even if writing fails.
    with open(outputfilename, 'w') as outfile:
        for key in MergedPairs:
            (chr1, merged_pos1_1, merged_pos1_2, chr2, merged_pos2_1, merged_pos2_2) = key
            outline = '\t'.join([chr1, str(merged_pos1_1), str(merged_pos1_2),
                                 chr2, str(merged_pos2_1), str(merged_pos2_2),
                                 str(MergedPairs[key])])
            outfile.write(outline + '\n')

# Run the pipeline only when executed as a script, so that importing this
# module does not read standard input or exit the interpreter.
if __name__ == '__main__':
    run()

