##################################
#                                #
# Last modified 07/20/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math

def _read_sam_list(path):
    """Return a {label: SAM filename} dict read from a tab-delimited list file.

    Each non-blank line is expected to be ``label <tab> filename``.  Blank
    lines are skipped; a later line with the same label overrides an
    earlier one.
    """
    sam_files = {}
    with open(path) as handle:
        for line in handle:
            stripped = line.strip()
            if not stripped:
                continue
            fields = stripped.split('\t')
            sam_files[fields[0]] = fields[1]
    return sam_files


def _read_exons(path):
    """Return {chromosome: {gene_id: [(start, stop), ...]}} for GTF exon records.

    Comment lines (starting with '#'), blank lines and records whose feature
    column is not 'exon' are ignored.  The gene ID is taken from the
    ``gene_id "..."`` attribute in the ninth column.
    """
    genes = {}
    with open(path) as handle:
        for line in handle:
            stripped = line.strip()
            if not stripped or line.startswith('#'):
                continue
            fields = stripped.split('\t')
            if fields[2] != 'exon':
                continue
            chrom = fields[0]
            gene_id = fields[8].split('gene_id "')[1].split('";')[0]
            exons = genes.setdefault(chrom, {}).setdefault(gene_id, [])
            exons.append((int(fields[3]), int(fields[4])))
    return genes


def _index_positions(genes):
    """Return (exonic, intronic): two {chromosome: set of positions} dicts.

    A position is exonic if any exon on the chromosome covers it.  A position
    is intronic if it lies between the smallest and largest exon coordinate
    of some gene and is not exonic.  GTF coordinates are 1-based with an
    inclusive end, hence the ``+ 1`` on every range.
    """
    exonic = {}
    intronic = {}
    for chrom in sorted(genes):
        print(chrom)
        exon_positions = set()
        # First pass: collect every exonic base on the chromosome, so that
        # the intron pass below sees the exons of all genes, not only the
        # genes visited so far.
        for exons in genes[chrom].values():
            for start, stop in exons:
                exon_positions.update(range(start, stop + 1))
        intron_positions = set()
        for exons in genes[chrom].values():
            coordinates = []
            for start, stop in exons:
                coordinates.append(start)
                coordinates.append(stop)
            for pos in range(min(coordinates), max(coordinates) + 1):
                if pos not in exon_positions:
                    intron_positions.add(pos)
        exonic[chrom] = exon_positions
        intronic[chrom] = intron_positions
    return exonic, intronic


def _hit_count(fields):
    """Return the integer value of the NH:i: tag of a SAM record, or None.

    All optional columns (index 11 onwards) are searched, so the tag is found
    regardless of how many other tags precede it.
    """
    for tag in fields[11:]:
        if tag.startswith('NH:i:'):
            return int(tag[len('NH:i:'):])
    return None


def _count_reads(sam_path, label, exonic, intronic, no_multi, no_unique):
    """Return {'Exonic': x, 'Intronic': y, 'Intergenic': z} for one SAM file.

    Every alignment line contributes ``1 / NH`` to the class of its start
    position (column 4).  Records without an NH tag are treated as unique
    (weight 1.0).  ``no_multi`` drops records with NH != 1 and ``no_unique``
    drops unique records.  Header lines ('@') and blank lines are skipped.
    """
    # The counters start as the integer 0 so that a class that receives no
    # reads is reported as '0', as the script always has.
    counts = {'Intronic': 0, 'Exonic': 0, 'Intergenic': 0}
    processed = 0
    with open(sam_path) as handle:
        for line in handle:
            if line.startswith('@') or not line.strip():
                continue
            processed += 1
            if processed % 1000000 == 0:
                print('%s %s lines processed' % (label, processed))
            fields = line.strip().split('\t')
            hits = _hit_count(fields)
            is_unique = hits is None or hits == 1
            if no_multi and not is_unique:
                continue
            if no_unique and is_unique:
                continue
            if hits is None:
                weight = 1.0
            else:
                weight = 1.0 / hits
            chrom = fields[2]
            start = int(fields[3])
            # Chromosomes absent from the annotation fall through to
            # 'Intergenic' because both lookups default to an empty tuple.
            if start in exonic.get(chrom, ()):
                counts['Exonic'] += weight
            elif start in intronic.get(chrom, ()):
                counts['Intronic'] += weight
            else:
                counts['Intergenic'] += weight
    return counts


def _write_table(outfile, labels, scores):
    """Write the tab-delimited class-by-label table to an open file object."""
    outfile.write('\t'.join(['#Class'] + list(labels)) + '\n')
    for category in ('Exonic', 'Intronic', 'Intergenic'):
        row = [category + ':']
        for label in labels:
            row.append(str(scores[label][category]))
        outfile.write('\t'.join(row) + '\n')


def run():
    """Classify SAM alignments as exonic, intronic or intergenic.

    Command line (read from sys.argv):
        gtf list-of-SAM-files outputfilename [-nomulti] [-nounique]

    The GTF exon records define the exonic and intronic positions; each SAM
    file named in the list is then scanned and every alignment is assigned,
    by its start position, to one of the three classes with weight 1/NH.
    Progress and per-label totals are printed to stdout, and a summary table
    (one column per label) is written to ``outputfilename``.

    Exits with status 1 after printing usage when fewer than three
    positional arguments are supplied.

    The print calls use a single pre-formatted string so that the output is
    the same under Python 2 and Python 3.
    """
    # Three positional arguments are required; sys.argv[3] is read below.
    if len(sys.argv) < 4:
        print('usage: python %s gtf list-of-SAM-files outputfilename [-nomulti] [-nounique]' % sys.argv[0])
        print('list-of-SAM-files format: label <tab> filename')
        print('Note: the script is not designed to deal with multi reads unless the NH: field is present; multi reads will be weighed by their multiplicity')
        sys.exit(1)

    gtf = sys.argv[1]
    SAMlist = sys.argv[2]
    outfilename = sys.argv[3]

    noMulti = '-nomulti' in sys.argv
    if noMulti:
        print('will discard multi reads')
    noUnique = '-nounique' in sys.argv
    if noUnique:
        print('will discard unique reads')

    SAMDict = _read_sam_list(SAMlist)

    GeneDict = _read_exons(gtf)
    print('finished inputting annotation')

    ExonicPosDict, IntronicPosDict = _index_positions(GeneDict)
    print('finished parsing gene models')

    Exonic = sum([len(positions) for positions in ExonicPosDict.values()])
    Intronic = sum([len(positions) for positions in IntronicPosDict.values()])
    print('total exonic space = %s bp' % Exonic)
    print('total intronic space = %s bp' % Intronic)

    print('finished introns')

    ScoreDict = {}
    labels = sorted(SAMDict)
    # The output file is opened before the (potentially long) SAM pass so
    # that an unwritable path fails immediately rather than at the very end.
    with open(outfilename, 'w') as outfile:
        for label in labels:
            ScoreDict[label] = _count_reads(SAMDict[label], label,
                                            ExonicPosDict, IntronicPosDict,
                                            noMulti, noUnique)
            print('%s Exonic reads: %s' % (label, ScoreDict[label]['Exonic']))
            print('%s Intronic reads: %s' % (label, ScoreDict[label]['Intronic']))
            print('%s Intergenic reads: %s' % (label, ScoreDict[label]['Intergenic']))

        _write_table(outfile, labels, ScoreDict)
# Script entry point: run() is called unconditionally at module level, so it
# executes whenever this file is run and also whenever it is imported.
run()
