##################################
#                                #
# Last modified 2017/06/12       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import pysam
import numpy

# FLAG field meaning (hex value, decimal value, description)
# 0x0001 1 the read is paired in sequencing, no matter whether it is mapped in a pair
# 0x0002 2 the read is mapped in a proper pair (depends on the protocol, normally inferred during alignment)
# 0x0004 4 the query sequence itself is unmapped
# 0x0008 8 the mate is unmapped
# 0x0010 16 strand of the query (0 for forward; 1 for reverse strand)
# 0x0020 32 strand of the mate
# 0x0040 64 the read is the first read in a pair
# 0x0080 128 the read is the second read in a pair
# 0x0100 256 the alignment is not primary (a read having split hits may have multiple primary alignment records)
# 0x0200 512 the read fails platform/vendor quality checks
# 0x0400 1024 the read is either a PCR duplicate or an optical duplicate

def FLAG(FLAG):
    """Decompose a SAM FLAG value into the individual bit values that are set.

    Args:
        FLAG: the integer FLAG field of an alignment record.

    Returns:
        A list with the value of every set bit, largest first, e.g.
        ``FLAG(99) == [64, 32, 2, 1]``.  Zero and negative values yield
        an empty list.

    The decomposition is done bit by bit rather than against a fixed table
    of known values, so flags carrying bits above 1024 (for instance 2048,
    the supplementary-alignment bit) are decomposed correctly as well.
    """

    FLAGList = []
    if FLAG <= 0:
        return FLAGList

    bit = 1
    while bit <= FLAG:
        if FLAG & bit:
            FLAGList.append(bit)
        bit <<= 1

    # Bits were collected in ascending order; callers get them largest first.
    FLAGList.reverse()
    return FLAGList

# BAM CIGAR operation codes that advance the position on the reference
# sequence: 0 (M), 2 (D), 3 (N), 7 (=) and 8 (X).
_REF_CONSUMING_CIGAR_OPS = (0, 2, 3, 7, 8)


def _read_chrom_sizes(path):
    """Read a chrom.sizes file into a list of (chrom, 0, length) tuples."""
    regions = []
    with open(path) as handle:
        for line in handle:
            if not line.strip():
                continue
            fields = line.strip().split('\t')
            regions.append((fields[0], 0, int(fields[1])))
    return regions


def _count_reads_from_index(bam_path):
    """Sum the mapped-read column of ``pysam.idxstats`` over all named references."""
    stats = pysam.idxstats(bam_path)
    # Depending on the pysam version the result is either one string or a
    # sequence of lines; normalise to a sequence of lines.
    if hasattr(stats, 'splitlines'):
        stats = stats.splitlines()
    total = 0
    for line in stats:
        fields = line.strip().split('\t')
        if len(fields) < 3 or fields[0] == '*':
            continue
        total += int(fields[2])
    return total


def _chrom_in_bam(samfile, chrom):
    """Return True if a small probe fetch on *chrom* works for this BAM file."""
    try:
        for _ in samfile.fetch(chrom, 0, 100):
            break
    except Exception:
        # Best effort, as before: any failure of the probe means "skip this
        # chromosome" (but SystemExit/KeyboardInterrupt are no longer eaten).
        return False
    return True


def _passes_read_filters(alignedread, readLengthRange, maxMMCigar, maxMMMD):
    """Return True if the alignment is a usable read 1 of a proper pair.

    *readLengthRange* is ``None`` or a ``(min, max)`` tuple; *maxMMCigar* and
    *maxMMMD* are ``None`` or the maximum mismatch count for the CIGAR-based
    and MD-tag-based filters respectively.
    """
    if not alignedread.is_proper_pair:
        return False
    if not alignedread.is_read1:
        return False
    flag = alignedread.flag
    # 0x80: second read in pair.  The previous table-based flag decomposition
    # also ended up discarding every alignment whose flag was 2048 or larger
    # (supplementary alignments); that behaviour is kept, but made explicit.
    if flag & 0x80 or flag >= 0x800:
        return False
    if readLengthRange is not None:
        (minRL, maxRL) = readLengthRange
        readLength = len(alignedread.seq)
        if readLength > maxRL or readLength < minRL:
            return False
    if maxMMCigar is not None:
        # NOTE(review): this counts X operations, not mismatched bases, so a
        # run such as "2X" counts once -- confirm that this is intended.
        mismatches = sum(1 for (op, length) in alignedread.cigar if op == 8)
        if mismatches > maxMMCigar:
            return False
    if maxMMMD is not None:
        # NOTE(review): every letter in the MD string is counted, which
        # includes deleted reference bases (after '^') -- confirm intent.
        mismatches = sum(1 for ch in alignedread.opt('MD') if ch.isalpha())
        if mismatches > maxMMMD:
            return False
    return True


def _alignment_weight(alignedread, doUniqueBAM, noMulti):
    """Return the 1/NH weight of an alignment, or None if it must be skipped.

    Exits the program if the NH tag is needed but missing.
    """
    if doUniqueBAM:
        return 1.0
    try:
        multiplicity = alignedread.opt('NH')
    except KeyError:
        print('no NH: tags in BAM file, exiting')
        sys.exit(1)
    if noMulti and multiplicity > 1:
        return None
    return 1.0 / multiplicity


def _count_reads_by_scanning(samfile, regions, readLengthRange, maxMMCigar,
                             maxMMMD, doUniqueBAM, noMulti):
    """Return the multiplicity-weighted number of usable alignments in *regions*."""
    total = 0
    processed = 0
    for (chrom, start, end) in regions:
        if not _chrom_in_bam(samfile, chrom):
            print('region %s %s %s not found in bam file, skipping' % (chrom, start, end))
            continue
        for alignedread in samfile.fetch(chrom, start, end):
            processed += 1
            if processed % 5000000 == 0:
                print('counting total number of reads %sM alignments processed %s %s %s'
                      % (processed // 1000000, chrom, alignedread.pos, end))
            if not _passes_read_filters(alignedread, readLengthRange, maxMMCigar, maxMMMD):
                continue
            weight = _alignment_weight(alignedread, doUniqueBAM, noMulti)
            if weight is not None:
                total += weight
    return total


def _load_exon_coverage(gtf_path, chrom):
    """Map every exon position on *chrom* to the genes covering it, per strand.

    Returns ``{'+': {pos: [label, ...]}, '-': {...}}`` where each label is
    ``gene_id|gene_name`` (``gene_id`` is used twice if there is no
    ``gene_name`` attribute).  Positions are taken from the GTF start/end
    columns exactly as before, i.e. ``range(start, end)``.
    """
    exonCoverage = {'+': {}, '-': {}}
    with open(gtf_path) as handle:
        for line in handle:
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            if len(fields) < 9:
                continue
            if fields[0] != chrom or fields[2] != 'exon':
                continue
            strand = fields[6]
            if strand not in exonCoverage:
                # Exons without a defined strand (e.g. '.') cannot be
                # assigned to either table; previously this raised KeyError.
                continue
            left = int(fields[3])
            right = int(fields[4])
            geneID = fields[8].split('gene_id "')[1].split('"')[0]
            if 'gene_name' in fields[8]:
                geneName = fields[8].split('gene_name "')[1].split('"')[0]
            else:
                geneName = geneID
            label = geneID + '|' + geneName
            strandCoverage = exonCoverage[strand]
            for i in range(left, right):
                strandCoverage.setdefault(i, []).append(label)
    return exonCoverage


def _reference_end(alignedread):
    """Return the alignment start plus the reference span of its CIGAR."""
    endPos = alignedread.pos
    for (op, length) in alignedread.cigar:
        if op in _REF_CONSUMING_CIGAR_OPS:
            endPos += length
    return endPos


def _collect_five_prime_ends(samfile, region, exonCoverage, readLengthRange,
                             maxMMCigar, maxMMMD, doUniqueBAM, noMulti, processed):
    """Accumulate weighted read-1 end positions for one chromosome.

    Forward-strand reads are recorded at their alignment start, reverse-strand
    reads at the end of their reference span.  Genes are looked up at the
    mate's position on the read's strand.

    Returns ``(coverage, processed)`` where ``coverage[strand][pos]`` is a
    dict with keys ``'counts'`` and ``'genes'`` and *processed* is the updated
    running count of alignments seen (used only for progress messages).
    """
    (chrom, start, end) = region
    coverage = {'+': {}, '-': {}}
    for alignedread in samfile.fetch(chrom, start, end):
        processed += 1
        if processed % 5000000 == 0:
            print('%sM alignments processed %s %s %s'
                  % (processed // 1000000, chrom, alignedread.pos, end))
        if not _passes_read_filters(alignedread, readLengthRange, maxMMCigar, maxMMMD):
            continue
        scaleby = _alignment_weight(alignedread, doUniqueBAM, noMulti)
        if scaleby is None:
            continue
        if alignedread.is_reverse:
            strand = '-'
            position = _reference_end(alignedread)
        else:
            strand = '+'
            position = alignedread.pos
        entry = coverage[strand].get(position)
        if entry is None:
            entry = {'counts': 0.0, 'genes': set()}
            coverage[strand][position] = entry
        entry['counts'] += scaleby
        # The mate position is read from the alignment object instead of
        # being parsed out of str(alignedread), whose layout is not an API.
        entry['genes'].update(exonCoverage[strand].get(alignedread.mpos, ()))
    return coverage, processed


def _write_cluster(outfile, chrom, strand, cluster, positions, normFactor):
    """Write one output line describing *cluster* (a sorted list of positions)."""
    peak = cluster[0]
    clustercounts = 0.0
    for pos in cluster:
        if positions[pos]['counts'] > positions[peak]['counts']:
            peak = pos
        clustercounts += positions[pos]['counts']
    # SI = 2 + sum(p * log2(p)) over the positions of the cluster.
    SI = 2
    genes = set()
    for pos in cluster:
        p = positions[pos]['counts'] / clustercounts
        SI += p * numpy.log2(p)
        genes.update(positions[pos]['genes'])
    columns = [chrom, str(cluster[0]), str(cluster[-1] + 1), strand, str(peak),
               str(clustercounts), str(clustercounts / normFactor), str(SI)]
    # As before, the genes column is omitted entirely when no gene matched.
    if genes:
        columns.append(','.join(sorted(genes)))
    outfile.write('\t'.join(columns) + '\n')


def _write_clusters(outfile, chrom, coverage, MergDist, MaxSingletonSize, normFactor):
    """Merge neighbouring positions into clusters and write them out.

    Consecutive positions closer than *MergDist* are joined.  A cluster made
    of a single position whose count is <= *MaxSingletonSize* is dropped.
    """
    for strand in ('+', '-'):
        positions = coverage[strand]
        posKeys = sorted(positions)
        if not posKeys:
            continue
        clusters = []
        cluster = [posKeys[0]]
        for pos in posKeys[1:]:
            if pos - cluster[-1] < MergDist:
                cluster.append(pos)
            else:
                clusters.append(cluster)
                cluster = [pos]
        # The final cluster of each strand used to be dropped silently.
        clusters.append(cluster)
        for cluster in clusters:
            if len(cluster) == 1 and positions[cluster[0]]['counts'] <= MaxSingletonSize:
                continue
            _write_cluster(outfile, chrom, strand, cluster, positions, normFactor)


def run():
    """Command-line entry point: cluster read-1 ends of a paired-end BAM file.

    Positional arguments (``sys.argv[1:8]``): BAM file, chrom.sizes file, GTF
    file, max_cluster_size, merge_distance, max_singleton_size and the output
    file name.  Optional switches: ``-nomulti``, ``-mismatchesMD M``,
    ``-mismatches M``, ``-readLength min max``, ``-chr chrA,chrB`` and
    ``-uniqueBAM``.

    The function first determines the total (multiplicity-weighted) number
    of usable reads for RPM normalisation, then, chromosome by chromosome,
    collects read-1 end positions, merges them into clusters and writes one
    tab-separated line per cluster to the output file.

    NOTE: max_cluster_size is parsed and echoed but is not applied anywhere.

    Exits with status 1 when too few arguments are given or when NH tags are
    required but missing.
    """

    # Seven positional arguments are required, i.e. argv must have 8 entries.
    if len(sys.argv) < 8:
        print('usage: python %s BAMfilename chrom.sizes GTF max_cluster_size merge_distance max_singleton_size outputfilename [-nomulti] [-mismatchesMD M] [-mismatches M] [-readLength min max] [-chr chrN1(,chrN2....)] [-uniqueBAM]' % sys.argv[0])
        print('\tUse the -mismatches option to specify the maximum number of mismatches allowed for an alignment to be considered; use the -mismatchesMD option if mismatches are specified with the MD special tag')
        print('\tThe script assumes paired-end data')
        sys.exit(1)

    BAM = sys.argv[1]
    chrominfo = sys.argv[2]
    GTF = sys.argv[3]
    MaxClusterSize = int(sys.argv[4])
    MergDist = int(sys.argv[5])
    MaxSingletonSize = int(sys.argv[6])
    outfilename = sys.argv[7]

    print('%s %s %s' % (MaxClusterSize, MergDist, MaxSingletonSize))

    chromInfoList = _read_chrom_sizes(chrominfo)

    readLengthRange = None
    if '-readLength' in sys.argv:
        optionIndex = sys.argv.index('-readLength')
        minRL = int(sys.argv[optionIndex + 1])
        maxRL = int(sys.argv[optionIndex + 2])
        readLengthRange = (minRL, maxRL)
        print('will only consider reads between %s and %s bp length' % (minRL, maxRL))

    # The two mismatch limits are kept apart so that giving both options no
    # longer makes the second one overwrite the first.
    maxMMMD = None
    if '-mismatchesMD' in sys.argv:
        maxMMMD = int(sys.argv[sys.argv.index('-mismatchesMD') + 1])
        print('Will only consider alignments with %s or less mismatches' % maxMMMD)

    maxMMCigar = None
    if '-mismatches' in sys.argv:
        maxMMCigar = int(sys.argv[sys.argv.index('-mismatches') + 1])
        print('Will only consider alignments with %s or less mismatches' % maxMMCigar)

    if '-chr' in sys.argv:
        wantedChroms = set(sys.argv[sys.argv.index('-chr') + 1].split(','))
        chromInfoList = [region for region in chromInfoList if region[0] in wantedChroms]

    noMulti = '-nomulti' in sys.argv
    if noMulti:
        print('will only consider unique alignments')

    doUniqueBAM = '-uniqueBAM' in sys.argv
    if doUniqueBAM:
        print('will treat all alignments as unique')

    samfile = pysam.Samfile(BAM, "rb")
    try:
        noReadFilters = readLengthRange is None and maxMMMD is None and maxMMCigar is None
        if doUniqueBAM and noReadFilters:
            # Shortcut: take the mapped-read totals straight from the index.
            TotalNumberRead = _count_reads_from_index(BAM)
        else:
            TotalNumberRead = _count_reads_by_scanning(
                samfile, chromInfoList, readLengthRange, maxMMCigar, maxMMMD,
                doUniqueBAM, noMulti)

        TotalNumberRead = round(TotalNumberRead)

        print('found %s reads' % TotalNumberRead)
        normFactor = TotalNumberRead / 1000000.
        print('RPM normalization Factor = %s' % normFactor)

        with open(outfilename, 'w') as outfile:
            outfile.write('#chr\tleft\tright\tstrand\tpeak\tcounts\tRPM\tSI\tgenes' + '\n')

            processed = 0
            for region in chromInfoList:
                (chrom, start, end) = region
                # Probe the BAM file first so the GTF is not parsed for
                # chromosomes that are going to be skipped anyway.
                if not _chrom_in_bam(samfile, chrom):
                    print('region %s %s %s not found in bam file, skipping' % (chrom, start, end))
                    continue
                exonCoverage = _load_exon_coverage(GTF, chrom)
                print('finished parsing GTF %s' % chrom)
                coverage, processed = _collect_five_prime_ends(
                    samfile, region, exonCoverage, readLengthRange, maxMMCigar,
                    maxMMMD, doUniqueBAM, noMulti, processed)
                _write_clusters(outfile, chrom, coverage, MergDist,
                                MaxSingletonSize, normFactor)
    finally:
        samfile.close()
            
# Only start the analysis when the file is executed as a script, so that
# importing the module does not parse sys.argv and start processing.
if __name__ == '__main__':
    run()
