##################################
#                                #
# Last modified 2018/01/03       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import pysam
import math
from sets import Set
import os

def FLAG(FLAG):
    """Decompose a SAM bitwise FLAG value into its individual bit values.

    Parameters:
        FLAG -- the integer FLAG field of an alignment

    Returns:
        A list of the powers of two that add up to FLAG, largest first
        (e.g. 99 -> [64, 32, 2, 1]).  Zero and negative values give [].

    No fixed table of bits is used, so flags that include 2048 (the SAM
    supplementary-alignment bit) or any higher bit are handled as well.
    """

    FLAGList = []

    if FLAG <= 0:
        return FLAGList

    # Find the highest power of two that still fits into the flag
    bit = 1
    while bit * 2 <= FLAG:
        bit *= 2

    # Peel bits off from the top so that the result comes out in descending order
    Residual = FLAG
    while bit >= 1:
        if bit <= Residual:
            Residual -= bit
            FLAGList.append(bit)
        bit //= 2

    return FLAGList

def run():

    if len(sys.argv) < 6:
        print 'usage: python %s miRBase.gff3 BAMfilename chrom.sizes 5p_radius 3p_radius outputfilename [-nomulti] [-totalReadNumber number] [-readLength min max] [-uniqueBAM] [-noNH samtools]' % sys.argv[0]
        print 'Note: the script will divide multireads by their multiplicity'
        print '\tuse the uniqueBAM option if the BAM file contains only unique alignments; this will save a lot of memory'
        print '\tuse the -noNH option and supply a path to samtools in order to have the file converted to one that has NH tags'
        print '\tthe radius parameter refers to the tolerance allowed around the 5p or 3p end of the miRNA -- set it to 0 to only include perfect matches at each end'
        print '\tthe script assumes ungapped alignments!!!'
        sys.exit(1)

    gff3 = sys.argv[1]
    SAM = sys.argv[2]
    chromSize = sys.argv[3]
    rad5p = int(sys.argv[4])
    rad3p = int(sys.argv[5])
    outfilename = sys.argv[6]

    chromInfoList=[]
    linelist=open(chromSize)
    for line in linelist:
        fields=line.strip().split('\t')
        chr=fields[0]
        start=0
        end=int(fields[1])
        chromInfoList.append((chr,start,end))

    noMulti=False
    if '-nomulti' in sys.argv:
        noMulti=True
        print 'will discard multi-read alignments'

    doDirectTRN = False
    if '-totalReadNumber' in sys.argv:
        doDirectTRN = True
        DTRN = int(sys.argv[sys.argv.index('-totalReadNumber')+1])

    doReadLength=False
    if '-readLength' in sys.argv:
        doReadLength=True
        minRL = int(sys.argv[sys.argv.index('-readLength')+1])
        maxRL = int(sys.argv[sys.argv.index('-readLength')+2])
        print 'will only consider reads between', minRL, 'and', maxRL, 'bp length'
        ORLL = 0

    doUniqueBAM = False
    if '-uniqueBAM' in sys.argv:
        print 'will treat all alignments as unique'
        doUniqueBAM = True
        TotalReads = 0
        pass

    samfile = pysam.Samfile(SAM, "rb" )
    try:
        print 'testing for NH tags presence'
        for alignedread in samfile.fetch():
            multiplicity = alignedread.opt('NH')
            print 'file has NH tags'
            break
    except:
        if '-noNH' in sys.argv:
            print 'no NH: tags in BAM file, will replace with a new BAM file with NH tags'
            samtools = sys.argv[sys.argv.index('-noNH')+1]
            BAMpreporcessingScript = sys.argv[0].rpartition('/')[0] + '/bamPreprocessing.py'
            cmd = 'python ' + BAMpreporcessingScript + ' ' + SAM + ' ' + SAM + '.NH'
            os.system(cmd)
            cmd = 'rm ' + SAM
            os.system(cmd)
            cmd = 'mv ' + SAM + '.NH' + ' ' + SAM
            os.system(cmd)
            cmd = samtools + ' index ' + SAM
            os.system(cmd)
        else:
            if doUniqueBAM:
                pass
            else:
                print 'no NH: tags in BAM file, exiting'
                sys.exit(1)

    regionDict={}

    Unique=0
    UniqueSplices=0
    Multi=0
    MultiSplices=0

    if doUniqueBAM and not doReadLength:
        TotalReads = 0
        try:
            for chrStats in pysam.idxstats(SAM):
                fields = chrStats.strip().split('\t')
                chr = fields[0]
                reads = int(fields[2])
                if chr != '*':
                    TotalReads += reads
        except:
            for chrStats in pysam.idxstats(SAM).strip().split('\n'):
                fields = chrStats.strip().split('\t')
                print fields
                chr = fields[0]
                reads = int(fields[2])
                if chr != '*':
                    TotalReads += reads
        UniqueReads = TotalReads
    elif doDirectTRN:
        TotalReads = DTRN
    else:
        MultiplicityDict={}
        UniqueReads = 0
        i=0
        samfile = pysam.Samfile(SAM, "rb" )
        for (chr,start,end) in chromInfoList:
            try:
                for alignedread in samfile.fetch(chr, start, end):
                    i+=1
                    if i % 5000000 == 0:
                        print str(i/1000000) + 'M alignments processed', chr,start,end
                    fields=str(alignedread).split('\t')
                    ID=fields[0]
                    if alignedread.is_read1:
                        if doEnd2Only:
                             continue
                        ID = ID + '/1'
                    if alignedread.is_read2:
                        if doEnd1Only:
                             continue
                        ID = ID + '/2'
                    if doReadLength:
                        if len(alignedread.seq) > maxRL or len(alignedread.seq) < minRL:
                            ORLL += 1
                            continue
                    if doUniqueBAM:
                        TotalReads+=1
                        continue
                    if alignedread.opt('NH') == 1:
                        UniqueReads += 1
                        continue
                    if MultiplicityDict.has_key(ID):
                        MultiplicityDict[ID]+=1
                    else:
                        MultiplicityDict[ID]=1
            except:
                print 'problem with region:', chr, start, end, 'skipping'
        if doReadLength:
            print ORLL, 'alignments outside of read length limits'
        if doUniqueBAM:
            pass
        else:
            TotalReads = UniqueReads + len(MultiplicityDict.keys())

    print 'TotalReads', TotalReads

    normalizeBy = TotalReads/1000000.

    outfile = open(outfilename, 'w')

    outline = '#miRNA\tchr\tleft\tright\tstrand\treads\tRPM'
    outfile.write(outline + '\n')

    i=0
    if gff3.endswith('.bz2'):
        cmd = 'bzip2 -cd ' + gff3
    elif gff3.endswith('.gz'):
        cmd = 'gunzip -c ' + gff3
    else:
        cmd = 'cat ' + gff3
    p = os.popen(cmd, "r")
    line = 'line'
    while line != '':
        line = p.readline()
        if line == '':
            break
        if line[0]=='#':
            continue
        fields = line.strip().split('\t')
        if fields[2] == 'miRNA_primary_transcript':
            continue
        i+=1
        if i % 100 == 0:
            print i, 'miRNAs processed'
        chr = fields[0]
        left = int(fields[3])
        right = int(fields[4])
        strand = fields[6]
        miRNA = fields[8].split(';Name=')[1].split(';')[0]
        if strand == '+':
            miRpos5 = left
            miRpos3 = right
        if strand == '-':
            miRpos5 = right
            miRpos3 = left
        reads=0
        try:
            for alignedread in samfile.fetch(chr, left, right):
                fields2 = str(alignedread).split('\t')
                if doReadLength:
                    if len(alignedread.seq) > maxRL or len(alignedread.seq) < minRL:
                        continue
                ID = fields2[0]
                if alignedread.is_reverse:
                    s = '-'
                else:
                    s = '+'
                if s != strand:
                    continue
                if s == '+':
                    pos5p = alignedread.pos + 1
                    pos3p = pos5p + len(alignedread.seq)
                if s == '-':
                    pos3p = alignedread.pos + 1
                    pos5p = pos3p + len(alignedread.seq)
                if (pos5p >= miRpos5 - rad5p) and (pos5p <= miRpos5 + rad5p):
                    pass
                else:
                    continue
                if (pos3p >= miRpos3 - rad3p) and (pos3p <= miRpos3 + rad3p):
                    pass
                else:
                    continue
                if doUniqueBAM:
                    reads += 1
                else:
                    if noMulti and alignedread.opt('NH') > 1:
                        continue
                    else:
                        reads += 1./alignedread.opt('NH')
        except:
            print 'problem with region:', chr, left, right, 'assigning 0 value'
            reads=0
        RPM = reads/normalizeBy
        outline = miRNA + '\t' + chr + '\t' + str(left) + '\t' + str(right) + '\t' + strand + '\t' + str(reads) + '\t' + str(RPM)
        outfile.write(outline + '\n')
         
    outfile.close()
   
# Script entry point: run() is called unconditionally, so the whole analysis
# starts as soon as this file is executed (or imported).
run()
