##################################
#                                #
# Last modified 8/26/2009         # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from cistematic.core import Genome
from cistematic.core.geneinfo import geneinfoDB
from commoncode import *

def _terminal_utr_spans(features, window):
    """Locate the terminal 5' and 3' UTR windows of one gene.

    ``features`` is the gene's feature list as returned by
    ``Genome.getallGeneFeatures()``.  Each entry is indexed the way this
    script uses it: [0] feature type, [1] chromosome, [2] and [3] the lower
    and upper coordinates, [4] the sense ('F' or 'R').

    Returns ``(sense, span5, span3)`` where each span is a
    ``(low, high)`` pair of genomic coordinates ``window`` bp long, or
    ``None`` when the gene does not qualify (no features, unknown sense,
    first/last feature is not a UTR, or either UTR is not longer than
    ``window``).
    """
    if not features:
        return None
    first = features[0]
    last = features[-1]
    sense = first[4]
    if sense not in ('F', 'R'):
        return None
    if first[0] != 'UTR' or last[0] != 'UTR':
        return None
    if first[3] - first[2] <= window or last[3] - last[2] <= window:
        return None

    # Window at the low-coordinate end of the first feature and at the
    # high-coordinate end of the last feature.
    low_span = (first[2], first[2] + window)
    high_span = (last[3] - window, last[3])
    if sense == 'F':
        return sense, low_span, high_span
    # On the reverse strand the 5' end of the gene is the high-coordinate end.
    return sense, high_span, low_span


def _bin_bounds(span, sense, offset, bin_end):
    """Map profile positions [offset, bin_end) onto genomic coordinates.

    Profile positions run in the direction of transcription, so forward
    genes are walked upward from the low end of ``span`` and reverse genes
    downward from its high end.  Returns ``(rmin, rmax)``.
    """
    low, high = span
    if sense == 'F':
        return low + offset, low + bin_end
    return high - bin_end, high - offset


def run():
    """Sum read counts over the terminal windows of 5' and 3' UTRs.

    Command line:
        genome rdsfilename windowlength stepsize outputfilename [-cache size]

    For every gene whose first and last features are UTRs longer than
    ``windowlength``, the window is cut into ``stepsize``-bp bins, the read
    count of each bin is fetched from the RDS dataset, and that count is
    added to every profile position the bin covers.  The summed (raw,
    un-normalized) profiles are written to ``outputfilename``.

    Exits with status 1 when required arguments are missing or invalid.
    """
    # Five positional arguments are read below, so argv needs six entries.
    if len(sys.argv) < 6:
        print('usage: python %s genome rdsfilename windowlength stepsize outputfilename  [-cache size] ' % sys.argv[0])

        sys.exit(1)

    genome = sys.argv[1]
    hitfilename = sys.argv[2]
    window = int(sys.argv[3])
    step = int(sys.argv[4])
    outfilename = sys.argv[5]

    if window <= 0 or step <= 0:
        print('windowlength and stepsize must be positive integers')
        sys.exit(1)

    cachePages = -1
    if '-cache' in sys.argv:
        try:
            cachePages = int(sys.argv[sys.argv.index('-cache') + 1])
        except (IndexError, ValueError):
            print('-cache requires an integer page count')
            sys.exit(1)

    outfile = open(outfilename, 'w')
    try:
        hitRDS = readDataset(hitfilename, verbose=True, cache=True)

        # sqlite default_cache_size is 2000 pages
        if cachePages > hitRDS.getDefaultCacheSize():
            hitRDS.setDBcache(cachePages)

        hg = Genome(genome)
        featDict = hg.getallGeneFeatures()

        # Header strings are reproduced verbatim in case downstream tooling
        # parses this report.
        outfile.write('Dataset:' + hitfilename + '\n')
        outfile.write('Calculated the average RPMs over the ' + str(window) + 'at the ends of the 5UTRs and 3UTRs with a step of ' + str(step) + 'bp\n')

        # Profile position j is the distance (in transcript direction) from
        # the start of the window.
        UTR5 = [0] * window
        UTR3 = [0] * window

        genes = 0
        totalGenes = len(featDict)
        for index, geneID in enumerate(featDict):
            # Progress report: genes still to be processed.
            if index % 1000 == 0:
                print(totalGenes - index)

            features = featDict[geneID]
            located = _terminal_utr_spans(features, window)
            if located is None:
                continue
            sense, span5, span3 = located
            genes += 1
            chrom = 'chr' + str(features[0][1])

            for offset in range(0, window, step):
                # Clamp the last bin so a step that does not divide the
                # window evenly cannot run past the end of the profile.
                binEnd = min(offset + step, window)
                for profile, span in ((UTR5, span5), (UTR3, span3)):
                    rmin, rmax = _bin_bounds(span, sense, offset, binEnd)
                    count = hitRDS.getCounts(chrom=chrom, rmin=rmin, rmax=rmax, uniqs=True, multi=True, splices=True, reportCombined=True)
                    for position in range(offset, binEnd):
                        profile[position] += count

        outline = 'Total genes analyzed: ' + str(genes) + '\n'
        print(outline)
        outfile.write(outline)

        outfile.write('5UTR dsitrubtion\n')
        for position, total in enumerate(UTR5):
            outfile.write(str(position) + '\t' + str(total) + '\n')

        outfile.write('\n\n3UTR dsitrubtion\n')
        for position, total in enumerate(UTR3):
            outfile.write(str(position) + '\t' + str(total) + '\n')
    finally:
        # Always release the output file, even if the dataset or genome
        # lookups fail part-way through.
        outfile.close()
   
# Script entry point. There is no __main__ guard, so this call executes
# whenever the file is run or imported.
run()
