##################################
#                                #
# Last modified 09/03/2013       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import pysam
import string
from sets import Set
import os
import subprocess

# SAM FLAG field bits
# 0x0001    1 the read is paired in sequencing, no matter whether it is mapped in a pair
# 0x0002    2 the read is mapped in a proper pair (depends on the protocol, normally inferred during alignment) [paired reads only]
# 0x0004    4 the query sequence itself is unmapped
# 0x0008    8 the mate is unmapped [paired reads only]
# 0x0010   16 strand of the query (0 for forward; 1 for reverse strand)
# 0x0020   32 strand of the mate [paired reads only]
# 0x0040   64 the read is the first read in a pair [paired reads only]
# 0x0080  128 the read is the second read in a pair [paired reads only]
# 0x0100  256 the alignment is not primary (a read having split hits may have multiple primary alignment records)
# 0x0200  512 the read fails platform/vendor quality checks
# 0x0400 1024 the read is either a PCR duplicate or an optical duplicate
# 0x0800 2048 supplementary alignment (part of a chimeric/split alignment)

def getstrand(FLAGfields):
    """Return the alignment strand encoded in a decoded SAM flag list.

    ``FLAGfields`` is a collection of flag bit values (as produced by FLAG()).
    The result is '-' when bit 16 (0x10, reverse strand) is present and '+'
    otherwise.
    """
    return '-' if 16 in FLAGfields else '+'

def FLAG(FLAG):
    """Decompose a SAM FLAG integer into the bit values it contains.

    Returns the set bits as powers of two in descending order, e.g.
    ``FLAG(83) == [64, 16, 2, 1]``.  Zero or a negative value gives ``[]``.

    Every bit is decoded, including 0x800 (2048) and above.  A hard-coded
    table that stops at 1024 cannot represent those bits, and the old
    subtraction loop then produced a corrupted list (for 2048 it contained
    16, flipping the strand reported by getstrand()).
    """
    bits = []
    bit = 1
    while bit <= FLAG:
        if FLAG & bit:
            bits.append(bit)
        bit <<= 1
    # collected low-to-high; callers historically received high-to-low
    bits.reverse()
    return bits

def _read_id(fields):
    """Return the read name of a tab-split SAM record with mate tags removed.

    The QNAME (``fields[0]``) is cut at the first space and then truncated at
    the first ``/2``, ``_2:``, ``/1`` and ``_1:`` (applied in that order), so
    both mates of a pair yield the same key.  The same order is used for both
    streams so the sortedness checks compare like with like.
    """
    name = fields[0].split(' ')[0]
    for tag in ('/2', '_2:', '/1', '_1:'):
        name = name.split(tag)[0]
    return name


def _alignments_by_transcript(records):
    """Index one read's alignments by (gene, transcript).

    ``records`` are tab-split ``samtools view`` lines for a single read.  The
    reference name (column 3) is split on ':' into gene and transcript; the
    unmapped reference '*' becomes ('*', '*').  For references containing
    ':CircJunc:' the position is replaced by the string
    ':CircJunc:<rest of reference name>::::<POS>', which marks the hit as a
    circular-junction hit for _score_read_pair().

    Returns {(gene, transcript): (pos, strand)} holding only the FIRST record
    seen for each key.
    """
    hits = {}
    for fields in records:
        ref = fields[2]
        if ref == '*':
            key = ('*', '*')
        else:
            parts = ref.split(':')
            key = (parts[0], parts[1])
        strand = getstrand(FLAG(int(fields[1])))
        pos = int(fields[3])
        if ':CircJunc:' in ref:
            pos = ':CircJunc:' + ref.split(':CircJunc:')[1] + '::::' + str(pos)
        # NOTE(review): later alignments of this read to the same transcript
        # are ignored; only the first one is scored.  This matches the
        # original behaviour, whose "exactly one alignment" check could
        # therefore never fail.  Confirm this is intended.
        hits.setdefault(key, (pos, strand))
    return hits


def _paired_read_groups(stream1, stream2):
    """Walk two identically name-sorted ``samtools view`` streams in lockstep.

    Yields ``(readID, records1, records2)`` for every read, where the record
    lists hold the tab-split lines for that read from each stream, including
    the last read of the files.  Prints a progress line every million end-1
    lines.  Exits the program (status 1) if the streams do not list the same
    reads in the same order.
    """
    line1 = stream1.readline()
    line2 = stream2.readline()
    if line1 == '' and line2 == '':
        return
    fields1 = line1.strip().split('\t')
    fields2 = line2.strip().split('\t')
    readID = _read_id(fields1)
    if readID != _read_id(fields2):
        print('files not properly sorted, beginning of file, exiting')
        sys.exit(1)
    records1 = [fields1]
    records2 = [fields2]
    i = 0
    j = 0
    while True:
        line1 = stream1.readline()
        if line1 == '':
            break
        i += 1
        if i % 1000000 == 0:
            print('%dM alignments processed in end1 %dM alignments processed in end2' % (i // 1000000, j // 1000000))
        fields1 = line1.strip().split('\t')
        ID1 = _read_id(fields1)
        if ID1 == readID:
            records1.append(fields1)
            continue
        # End 1 has moved on to ID1: take the rest of readID's records from
        # end 2.  The first non-matching end-2 line (or EOF) belongs to the
        # next read.
        while True:
            j += 1
            line2 = stream2.readline()
            fields2 = line2.strip().split('\t')
            ID2 = _read_id(fields2)
            if line2 == '' or ID2 != readID:
                break
            records2.append(fields2)
        yield readID, records1, records2
        if ID1 != ID2:
            print('files not properly sorted, middle of file, exiting')
            print('%s %s %s %s' % (ID1, ID2, i, j))
            sys.exit(1)
        readID = ID1
        records1 = [fields1]
        records2 = [fields2]
    # End 1 is exhausted.  Collect end 2's remaining records for the last
    # read and emit that pair as well (previously it was silently dropped).
    while True:
        line2 = stream2.readline()
        if line2 == '':
            break
        fields2 = line2.strip().split('\t')
        if _read_id(fields2) != readID:
            print('warning: end2 has alignments after the last read of end1; ignoring them')
            break
        records2.append(fields2)
    yield readID, records1, records2


def _score_read_pair(readID, records1, records2, outfileReads, GeneReadCountDict,
                     doCircJuncsOnly, doSameStrandOnly, doNotOnSameStrandOnly):
    """Classify one read pair, write its .reads lines and update gene counts.

    For every (gene, transcript) hit by both ends:
      * if either end hit a ':CircJunc:' reference, a 'circular' line is
        written and the gene is marked circular;
      * otherwise, if the end-1 position is greater than the end-2 position
        and the strand filters allow it, a distance/strand line is written
        and the gene is marked circular.
    A gene counts as circular for this pair if ANY shared transcript
    qualifies.  Each gene with a shared transcript is then counted in
    ``GeneReadCountDict`` under 'unique' (the pair hit exactly one gene) or
    'multi' (several genes), plus the matching 'circ_*' counter if circular.
    """
    hits1 = _alignments_by_transcript(records1)
    hits2 = _alignments_by_transcript(records2)
    AlignedGenes = {}  # gene -> True once circular evidence is seen
    for key, (pos1, strand1) in hits1.items():
        if key not in hits2:
            continue
        gene, transcript = key
        # setdefault: a later linear transcript must not reset a gene that an
        # earlier transcript already marked circular.
        AlignedGenes.setdefault(gene, False)
        pos2, strand2 = hits2[key]
        if isinstance(pos1, str) or isinstance(pos2, str):
            outfileReads.write('\t'.join([readID, gene, transcript, str(pos1), str(pos2), 'circular', 'circular']) + '\n')
            AlignedGenes[gene] = True
        elif pos1 > pos2:
            if doCircJuncsOnly:
                continue
            SameStrand = strand1 if strand1 == strand2 else 'no'
            if doSameStrandOnly and SameStrand == 'no':
                continue
            if doNotOnSameStrandOnly and SameStrand != 'no':
                continue
            outfileReads.write('\t'.join([readID, gene, transcript, str(pos1), str(pos2), str(pos1 - pos2), SameStrand]) + '\n')
            AlignedGenes[gene] = True
    if not AlignedGenes:
        return
    category = 'unique' if len(AlignedGenes) == 1 else 'multi'
    for gene, isCircular in AlignedGenes.items():
        counts = GeneReadCountDict.setdefault(
            gene, {'unique': 0, 'multi': 0, 'circ_unique': 0, 'circ_multi': 0})
        counts[category] += 1
        if isCircular:
            counts['circ_' + category] += 1


def run():
    """Command-line entry point; see the usage message for the arguments.

    Streams ``samtools view`` output for both mate BAMs, groups alignments by
    read name, and writes ``<prefix>.reads`` (one line per qualifying read
    pair and transcript) and ``<prefix>.genes`` (per-gene unique/multi pair
    counts, the circular subset of each, and the circular fraction or 'NaN').
    """
    # four positional arguments are required: argv[1] .. argv[4]
    if len(sys.argv) < 5:
        print('usage: python %s samtools read1_BAM read2_BAM outfile_prefix [-circJuncsOnly] [-SameStrandOnly] [-NotOnSameStrandOnly]' % sys.argv[0])
        print('\tBAM file should be generated by aligning against the transcriptome with bowtie')
        print('\treadIDs are assumed to be sorted in the same way in both files')
        print('\tif you aling using more than one thread, reads may not be properly sorted, therefore it is advised that multiple samples be aligned in parallel with 1 CPU only rather than aligning each replicate with a lot of CPUs')
        print('\tNote: -SameStrandOnly and -NotOnSameStrandOnly overwrite -circJuncsOnly and can not be true simultaneously')
        sys.exit(1)

    samtools = sys.argv[1]
    BAM1 = sys.argv[2]
    BAM2 = sys.argv[3]
    outfileprefix = sys.argv[4]

    doCircJuncsOnly = '-circJuncsOnly' in sys.argv
    doSameStrandOnly = '-SameStrandOnly' in sys.argv
    doNotOnSameStrandOnly = '-NotOnSameStrandOnly' in sys.argv
    if doNotOnSameStrandOnly and doSameStrandOnly:
        print('logical error, both -SameStrandOnly and -NotOnSameStrandOnly are true; exiting')
        sys.exit(1)
    if doSameStrandOnly or doNotOnSameStrandOnly:
        # the strand filters override -circJuncsOnly
        doCircJuncsOnly = False

    # NOTE(review): as in the original, "end 1" (pos1 / Read_1_pos) is read
    # from BAM2 and "end 2" from BAM1.  Confirm this swap is intended.
    # Argument lists (no shell) pass paths with spaces or metacharacters
    # through unchanged.
    proc1 = subprocess.Popen([samtools, 'view', BAM2], stdout=subprocess.PIPE,
                             universal_newlines=True)
    proc2 = subprocess.Popen([samtools, 'view', BAM1], stdout=subprocess.PIPE,
                             universal_newlines=True)
    GeneReadCountDict = {}
    try:
        with open(outfileprefix + '.genes', 'w') as outfileGenes:
            with open(outfileprefix + '.reads', 'w') as outfileReads:
                outfileReads.write('#readID\tGeneName\tTranscriptName\tRead_1_pos\tRead_2_pos\tDistance\tSameStrand' + '\n')
                outfileGenes.write('#GeneName\tTotal_Unique_Read_Pairs\tTotal_Unique_Circular_Pairs\tFraction_Unique\tTotal_Multi_Read_Pairs\tTotal_Multi_Circular_Pairs\tFraction_Multi' + '\n')
                for readID, records1, records2 in _paired_read_groups(proc1.stdout, proc2.stdout):
                    _score_read_pair(readID, records1, records2, outfileReads,
                                     GeneReadCountDict, doCircJuncsOnly,
                                     doSameStrandOnly, doNotOnSameStrandOnly)

            for gene in sorted(GeneReadCountDict):
                counts = GeneReadCountDict[gene]
                unique = counts['unique']
                circ_unique = counts['circ_unique']
                multi = counts['multi']
                circ_multi = counts['circ_multi']
                FractionUnique = circ_unique / float(unique) if unique > 0 else 'NaN'
                FractionMulti = circ_multi / float(multi) if multi > 0 else 'NaN'
                outfileGenes.write('\t'.join([gene, str(unique), str(circ_unique), str(FractionUnique), str(multi), str(circ_multi), str(FractionMulti)]) + '\n')
    finally:
        # reap the samtools children instead of leaving them to the OS
        for proc in (proc1, proc2):
            proc.stdout.close()
            proc.wait()
        
# Run only when executed as a script, so importing this module (e.g. to reuse
# FLAG/getstrand) does not parse sys.argv or launch samtools.
if __name__ == '__main__':
    run()