##################################
#                                #
# Last modified 03/22/2009       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from cistematic.core import Genome
from cistematic.core.geneinfo import geneinfoDB
from commoncode import *

def _exonPosition(sense, index, numexons):
    """Classify an exon as '5exon', '3exon' or 'middleexon'.

    index is the exon's position in the annotation's exon list. For a '+'
    transcript the first listed exon is the 5' exon and the last is the 3'
    exon; for a '-' transcript the roles are swapped. When a transcript has
    a single exon the "first exon" rule wins.
    """
    isFirst = (index == 0)
    isLast = (index == numexons - 1)
    if sense == '+':
        if isFirst:
            return '5exon'
        if isLast:
            return '3exon'
    elif sense == '-':
        if isFirst:
            return '3exon'
        if isLast:
            return '5exon'
    return 'middleexon'


def _readExonAnnotation(annotationLines, doCollapse):
    """Build the exon dictionary from tab-delimited annotation lines.

    Columns used (0-based): 0 = transcript ID, 1 = chromosome, 2 = strand,
    7 = exon count, 8 = comma-separated exon starts, 9 = comma-separated
    exon ends.

    Returns (exonDict, skipped) where exonDict maps an exon ID to the tuple
    (chrom, start, stop, sense, antisense, transcriptID, exonKind) and
    skipped is the number of lines ignored because the strand was neither
    '+' nor '-'. Exons sharing an ID overwrite each other, so the last one
    read is kept.
    """
    oppositeStrand = {'+': '-', '-': '+'}
    exonDict = {}
    skipped = 0
    for line in annotationLines:
        line = line.strip()
        # Blank lines and '#' comment/header lines carry no exon records.
        if not line or line.startswith('#'):
            continue
        fields = line.split('\t')
        UCSCID = fields[0]
        chrom = fields[1]
        sense = fields[2]
        if sense not in oppositeStrand:
            # Without a strand there is no meaningful antisense direction.
            skipped += 1
            continue
        antisense = oppositeStrand[sense]
        numexons = int(fields[7])
        exonStarts = fields[8].split(',')
        exonEnds = fields[9].split(',')
        for index in range(numexons):
            if doCollapse:
                # Collapsed IDs key on a single coordinate: the exon start
                # for '+' transcripts and the exon end for '-' transcripts.
                if sense == '+':
                    exonID = chrom + ':' + exonStarts[index]
                else:
                    exonID = chrom + ':' + exonEnds[index]
            else:
                exonID = chrom + ':' + exonStarts[index] + '-' + exonEnds[index]
            exonKind = _exonPosition(sense, index, numexons)
            exonDict[exonID] = (chrom, int(exonStarts[index]), int(exonEnds[index]),
                                sense, antisense, UCSCID, exonKind)
    return exonDict, skipped


def _rpkm(count, length, normalizeBy):
    """Return count / ((length / 1000) * normalizeBy), or 0.0 if undefined.

    The denominator is zero for a zero-length exon or an empty dataset;
    returning 0.0 avoids a ZeroDivisionError in those cases.
    """
    denominator = (length / 1000.) * normalizeBy
    if denominator == 0:
        return 0.0
    return count / denominator


def run():
    """Report exons whose antisense read density reaches a minimum RPKM.

    Command line (read from sys.argv):
        rdsfilename gtf_file minAntisenseRPKM outputfilename
        [-withmulti] [-cache size] [-collapse]

    -withmulti  passes multi=True to the read counting calls
    -cache N    raises the dataset DB cache to N pages when N exceeds the
                dataset's default cache size
    -collapse   keys exons on a single coordinate instead of start-stop

    Writes three '#' summary lines with per-category exon counts, a '#'
    column header, and then one tab-delimited row per exon whose antisense
    RPKM is at least minAntisenseRPKM. Each reported exon is also echoed to
    stdout. Exits with status 1 after printing the usage line when the
    required arguments are missing or the -cache value is not an integer.
    """
    usage = 'usage: python %s rdsfilename gtf_file minAntisenseRPKM outputfilename [-withmulti]  [-cache size] [-collapse]' % sys.argv[0]

    # Four positional arguments are required, so argv needs at least 5 items.
    if len(sys.argv) < 5:
        print(usage)
        sys.exit(1)

    hitfilename = sys.argv[1]
    gtf = sys.argv[2]
    minAntisenseRPKM = float(sys.argv[3])
    outfilename = sys.argv[4]

    cachePages = -1
    if '-cache' in sys.argv:
        try:
            cachePages = int(sys.argv[sys.argv.index('-cache') + 1])
        except (IndexError, ValueError):
            print(usage)
            sys.exit(1)

    doMulti = '-withmulti' in sys.argv
    doCollapse = '-collapse' in sys.argv

    outfile = open(outfilename, 'w')
    try:
        # NOTE(review): cache=True is passed regardless of the -cache flag;
        # confirm whether caching was meant to be conditional.
        hitRDS = readDataset(hitfilename, verbose=True, cache=True)

        if cachePages > hitRDS.getDefaultCacheSize():
            hitRDS.setDBcache(cachePages)

        # NOTE(review): the dataset length is counted twice here, which
        # halves every RPKM relative to len(hitRDS) / 1e6. Kept as-is to
        # preserve existing output values; confirm whether this is intended.
        normalizeBy = (len(hitRDS) + len(hitRDS)) / 1000000.
        print('normalizing factor: %s' % normalizeBy)

        annotationFile = open(gtf)
        try:
            exonDict, skipped = _readExonAnnotation(annotationFile, doCollapse)
        finally:
            annotationFile.close()
        if skipped:
            sys.stderr.write('skipped %d annotation lines with unrecognized strand\n' % skipped)

        # Tally exon categories over the de-duplicated exon dictionary.
        exonType = {'5exon': 0, '3exon': 0, 'middleexon': 0}
        for record in exonDict.values():
            exonType[record[6]] += 1
        outfile.write('#5 exons: ' + str(exonType['5exon']) + '\n')
        outfile.write('#3 exons: ' + str(exonType['3exon']) + '\n')
        outfile.write('#middle exons: ' + str(exonType['middleexon']) + '\n')
        print('finished parsing gene annotation')

        # Column header matching the rows written below.
        outfile.write('#exonID\tchr\tstart\tstop\tsense\tantisense\tUCSC_transcript_ID\texonType\tsenseRPKM\tantisenseRPKM\n')

        for exonID in exonDict:
            (chrom, rmin, rmax, sense, antisense, UCSCID, exonKind) = exonDict[exonID]
            exonLength = rmax - rmin
            antisenseCounts = hitRDS.getCounts(chrom=chrom, rmin=rmin, rmax=rmax, uniqs=True,
                                               multi=doMulti, splices=True,
                                               reportCombined=True, sense=antisense)
            antisenseRPKM = _rpkm(antisenseCounts, exonLength, normalizeBy)
            if antisenseRPKM < minAntisenseRPKM:
                continue
            # Sense counts are only needed for exons that pass the filter.
            senseCounts = hitRDS.getCounts(chrom=chrom, rmin=rmin, rmax=rmax, uniqs=True,
                                           multi=doMulti, splices=True,
                                           reportCombined=True, sense=sense)
            senseRPKM = _rpkm(senseCounts, exonLength, normalizeBy)
            row = (exonID, chrom, rmin, rmax, sense, antisense, UCSCID, exonKind,
                   senseRPKM, antisenseRPKM)
            print('%s %s %s %s %s %s %s %s %s %s' % row)
            outline = '%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t' % row
            outfile.write(outline + '\n')
    finally:
        outfile.close()
   
# Invoke the analysis unconditionally at module load: this happens both when
# the file is executed as a script and when it is imported.
run()

