##################################
#                                #
# Last modified 08/29/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import pysam
from sets import Set

def _read_chrom_sizes(chrominfo):
    """Parse a chrom.sizes file into (chrom, 0, length) regions.

    Each usable line holds a chromosome name and its length separated by a
    tab; extra columns are ignored and blank or one-column lines are skipped.
    """
    regions = []
    with open(chrominfo) as handle:
        for line in handle:
            fields = line.strip().split('\t')
            if len(fields) < 2:
                continue  # blank / malformed line
            regions.append((fields[0], 0, int(fields[1])))
    return regions


def _read_id(alignedread):
    """Return the read name, suffixed '/1' or '/2' for first/second mates."""
    read_id = alignedread.qname
    if alignedread.is_read1:
        read_id += '/1'
    if alignedread.is_read2:
        read_id += '/2'
    return read_id


def _count_mapped_reads(samfile, regions):
    """Count distinct reads aligned within `regions` of an indexed BAM.

    Mates are counted separately (see _read_id).  A read whose ID turns up
    in more than one alignment is counted once as a multiread; every other
    read is counted once as unique.  Regions on references absent from the
    BAM header are skipped (fetch would otherwise raise).

    Returns (unique, multi).
    """
    references = set(samfile.references)
    unique = 0
    multi = 0
    seen = set()        # IDs encountered at least once
    seen_twice = set()  # IDs already moved from the unique to the multi tally
    processed = 0
    for (chrom, start, end) in regions:
        print('%s %s %s' % (chrom, start, end))
        if chrom not in references:
            continue
        for alignedread in samfile.fetch(chrom, start, end):
            processed += 1
            if processed % 5000000 == 0:
                print(str(processed // 1000000) + 'M alignments processed')
            read_id = _read_id(alignedread)
            if read_id not in seen:
                seen.add(read_id)
                unique += 1
            elif read_id not in seen_twice:
                seen_twice.add(read_id)
                unique -= 1
                multi += 1
    return unique, multi


def _read_gtf_exons(gtf):
    """Collect exon records from a GTF file, grouped by gene.

    Returns a dict mapping (gene_id, gene_name) to a set of
    (chrom, left, right, strand) tuples in the GTF's 1-based, inclusive
    coordinates.  gene_id doubles as gene_name when the latter is absent.
    Comment lines and lines with fewer than 9 columns are skipped.
    """
    genes = {}
    with open(gtf) as handle:
        for line in handle:
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            if len(fields) < 9 or fields[2] != 'exon':
                continue
            attributes = fields[8]
            gene_id = attributes.split('gene_id "')[1].split('";')[0]
            if 'gene_name "' in attributes:
                gene_name = attributes.split('gene_name "')[1].split('";')[0]
            else:
                gene_name = gene_id
            exon = (fields[0], int(fields[3]), int(fields[4]), fields[6])
            # a set drops exons repeated across transcripts of the same gene
            genes.setdefault((gene_id, gene_name), set()).add(exon)
    return genes


def _merge_exon_blocks(exons):
    """Merge exons into sorted, non-overlapping 0-based half-open blocks.

    `exons` holds (chrom, left, right, strand) tuples with 1-based inclusive
    coordinates, i.e. the 0-based half-open interval [left - 1, right).
    Overlapping or abutting exons are merged; inverted exons are dropped.

    Returns (blocks, total_bp): a list of (start, end) tuples and the number
    of distinct bases they cover.
    """
    intervals = sorted((left - 1, right)
                       for (_chrom, left, right, _strand) in exons
                       if right >= left)
    merged = []
    for start, end in intervals:
        if merged and start <= merged[-1][1]:
            merged[-1][1] = max(merged[-1][1], end)
        else:
            merged.append([start, end])
    blocks = [(start, end) for start, end in merged]
    total_bp = sum(end - start for start, end in blocks)
    return blocks, total_bp


def _count_gene_reads(samfile, chrom, blocks):
    """Tally alignments overlapping `blocks` (0-based half-open) on `chrom`.

    An alignment overlapping several blocks (e.g. a spliced read) is counted
    once.  Returns (unique_reads, multi_reads, multi_weight) as floats:
    alignments with NH == 1, alignments with NH > 1, and the sum of 1/NH
    over the latter.  Exits with status 1 if an alignment lacks an NH tag.
    """
    unique_reads = 0.0
    multi_reads = 0.0
    multi_weight = 0.0
    counted = set()
    for (start, end) in blocks:
        for alignedread in samfile.fetch(chrom, start, end):
            # identify this particular alignment so spliced reads spanning
            # several blocks are not counted once per block
            key = (_read_id(alignedread), alignedread.pos, alignedread.is_reverse)
            if key in counted:
                continue
            counted.add(key)
            try:
                nh = float(alignedread.opt('NH'))
            except KeyError:
                print('multireads not specified with the NH tag, exiting')
                sys.exit(1)
            if nh > 1:
                multi_reads += 1
                # each of the NH alignments of a multiread carries 1/NH of it
                multi_weight += 1.0 / nh
            else:
                unique_reads += 1
    return unique_reads, multi_reads, multi_weight


def run():
    """Write per-gene read counts and RPMs for a BAM file to a TSV file.

    Command line (read from sys.argv):
        BAM             indexed BAM; every alignment must carry an NH tag
        chrom.sizes     tab-separated chromosome name and length
        GTF             annotation; only 'exon' records are used
        outputfilename  destination; one line per (gene_id, gene_name)

    The optional -splices and -RPKM flags named in the usage text are not
    implemented and are ignored.

    The RPM denominator is the number of distinct reads (mates counted
    separately) aligned to the chromosomes listed in chrom.sizes.  Unique
    alignments (NH == 1) count 1 each toward unique_RPM; multi-mapping
    alignments count 1/NH each toward multi_RPM, while multi_reads reports
    the raw number of such alignments.

    Exits with status 1 on bad usage, on an alignment without an NH tag, or
    when no alignments are found.
    """
    # program name + 4 positional arguments are required (argv[4] is read)
    if len(sys.argv) < 5:
        print('usage: python %s BAM chrom.sizes GTF outputfilename [-splices] [-RPKM]' % sys.argv[0])
        print('       BAM file has to be indexed')
        print('       the script needs an NH attribute in all alignments')
        print('       multireads will be weighed by their NH attribute')
        print('       the script will output separate RPMs from unique and multi reads')
        sys.exit(1)

    SAM = sys.argv[1]
    chrominfo = sys.argv[2]
    GTF = sys.argv[3]
    outputfilename = sys.argv[4]

    chromInfoList = _read_chrom_sizes(chrominfo)

    with open(outputfilename, 'w') as outfile:
        outfile.write('#gene_id\tgene_name\tchr\tstart\tend\tstrand\ttotal_bp'
                      '\tunique_reads\tunique_RPM\tmulti_reads\tmulti_RPM\n')

        print('examining reads')
        samfile = pysam.Samfile(SAM, "rb")
        try:
            Unique, Multi = _count_mapped_reads(samfile, chromInfoList)
            NormFactor = (Unique + Multi) / 1000000.0
            print('NormFactor %s' % NormFactor)
            if NormFactor == 0:
                print('no alignments found on the chromosomes in %s, exiting' % chrominfo)
                sys.exit(1)

            GeneDict = _read_gtf_exons(GTF)
            print('finished inputting annotation')

            references = set(samfile.references)
            for g, (gene_id, gene_name) in enumerate(GeneDict, 1):
                if g % 10000 == 0:
                    print('%d genes processed' % g)
                exons = sorted(GeneDict[(gene_id, gene_name)])
                # A gene is reported on one chromosome/strand; genes annotated
                # on several (e.g. pseudoautosomal) use only the first one here.
                chrom, strand = exons[0][0], exons[0][3]
                exons = [exon for exon in exons if exon[0] == chrom]
                blocks, TotalBP = _merge_exon_blocks(exons)
                if chrom in references:
                    UniqueReads, MultiReads, MultiReadsWeight = \
                        _count_gene_reads(samfile, chrom, blocks)
                else:
                    UniqueReads = MultiReads = MultiReadsWeight = 0.0
                UniqueRPM = UniqueReads / NormFactor
                MultiRPM = MultiReadsWeight / NormFactor
                # start/end are reported in the GTF's own 1-based coordinates
                start = min(exon[1] for exon in exons)
                end = max(exon[2] for exon in exons)
                outfile.write('\t'.join([
                    gene_id, gene_name, chrom, str(start), str(end), strand,
                    str(TotalBP), str(UniqueReads), str(UniqueRPM),
                    str(MultiReads), str(MultiRPM),
                ]) + '\n')
        finally:
            samfile.close()

# Run only when executed as a script, so the module can be imported without
# side effects.
if __name__ == '__main__':
    run()
