##################################
#                                #
# Last modified 9/28/2009         # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from cistematic.core import Genome
from cistematic.core.geneinfo import geneinfoDB

# psyco is an optional accelerator: the script only reports its absence and
# carries on without it.
try:
    import psyco
    psyco.full()
except Exception:
    # Catch Exception rather than using a bare except so that
    # KeyboardInterrupt and SystemExit are not silently swallowed.
    print('psyco not running')

from commoncode import *

def _formatValue(value):
    """Return value formatted with four significant digits.

    This replaces slicing str(value) to five characters, which silently
    dropped the exponent of numbers rendered in scientific notation
    (e.g. '1.234e-05' became '1.234') and digits of numbers >= 100000.
    """
    return '%.4g' % value


def _getWindow(features, radius):
    """Return (chrom, start, stop) for the window of a gene, or None.

    features is the list of feature records for one gene; index 1 holds the
    chromosome, 2 the left coordinate, 3 the right coordinate and 4 the
    orientation ('F' or 'R'). None is returned when the gene has no
    features or its orientation is neither 'F' nor 'R'.
    """
    if not features:
        return None

    chrom = 'chr' + str(features[0][1])
    orientation = str(features[0][4])
    if orientation == 'F':
        center = min([int(feature[2]) for feature in features])
    elif orientation == 'R':
        # NOTE(review): the smallest right coordinate is used here, exactly as
        # in the original code. If the window is meant to be centred on the
        # start of a reverse-strand gene this should probably be max() --
        # confirm before changing.
        center = min([int(feature[3]) for feature in features])
    else:
        return None

    return chrom, center - radius, center + radius


def run():
    """Write per-gene read densities in a window of +/- radius to a file.

    Command line:
        genome radius rdsfile outfilename [-control controlfilename]
        [-rpkm] [-cache size]

    For every gene returned by Genome(genome).getallGeneFeatures() the reads
    of the RDS file falling in the window (unique and multi reads, no
    splices) are counted and normalized per million reads in the dataset;
    with -rpkm the value is additionally divided by the window length in kb.
    With -control the same is done for the control dataset (with a
    pseudocount of 1 added to its raw count) and the hit/control ratio is
    written as well.

    Exits with status 1 after printing the usage message when fewer than the
    four required arguments are given.
    """
    # Four positional arguments are required, so argv must hold at least
    # five entries (the original test of < 4 let sys.argv[4] raise IndexError).
    if len(sys.argv) < 5:
        print('usage: python %s genome radius rdsfile outfilename [-control controlfilename] [-rpkm] [-cache size]' % sys.argv[0])
        sys.exit(1)

    genome = sys.argv[1]
    radius = int(sys.argv[2])
    hitfilename = sys.argv[3]
    outfilename = sys.argv[4]
    doRPKM = '-rpkm' in sys.argv

    outfile = open(outfilename, 'w')
    try:
        cachePages = -1
        doCache = False
        if '-cache' in sys.argv:
            doCache = True
            cachePages = int(sys.argv[sys.argv.index('-cache') + 1])

        doControl = False
        ctrlRDS = None
        if '-control' in sys.argv:
            controlfilename = sys.argv[sys.argv.index('-control') + 1]
            doControl = True
            ctrlRDS = readDataset(controlfilename, verbose=True, cache=doCache)

        hitRDS = readDataset(hitfilename, verbose=True, cache=doCache)

        # sqlite default_cache_size is 2000 pages
        if cachePages > hitRDS.getDefaultCacheSize():
            hitRDS.setDBcache(cachePages)
            if doControl:
                ctrlRDS.setDBcache(cachePages)

        # Counts are reported per million reads of the respective dataset.
        normalizeBy = len(hitRDS) / 1000000.
        if doControl:
            ctrlnormalizeBy = len(ctrlRDS) / 1000000.

        if doRPKM:
            unit = 'RPKM'
        else:
            unit = 'RPM'
        if doControl:
            header = 'GeneName\tstart\tstop\t%s\tCtrl%s\tratio\n' % (unit, unit)
        else:
            header = 'GeneName\tstart\tstop\t%s\n' % unit
        outfile.write(header)

        hg = Genome(genome)
        idb = geneinfoDB()
        # The returned dictionary was never used by the original code; the
        # call is kept in case it primes geneinfoDB -- TODO confirm and drop.
        idb.getallGeneInfo(genome)
        featDict = hg.getallGeneFeatures()

        for i, k in enumerate(featDict.keys()):
            # Progress indicator.
            if i % 1000 == 0:
                print(i)

            geneInfo = idb.getGeneInfo((genome, k))
            if geneInfo == []:
                name = 'LOC' + str(k)
            else:
                name = geneInfo[0]

            window = _getWindow(featDict[k], radius)
            if window is None:
                # Previously such a gene reused the window of the preceding
                # gene (or raised NameError when it came first).
                sys.stderr.write('skipping %s: no features or unknown orientation\n' % name)
                continue

            chrom, start, stop = window
            if doRPKM:
                length = abs(stop - start) / 1000.
            else:
                length = 1.

            value = float(hitRDS.getCounts(chrom=chrom, rmin=start, rmax=stop, uniqs=True, multi=True, splices=False, reportCombined=True))
            # Normalize in every mode so the column matches its RPM/RPKM
            # header (the original wrote raw counts when no control was given).
            value = value / normalizeBy / length
            fields = [name, str(start), str(stop), _formatValue(value)]
            if doControl:
                # Pseudocount of 1 keeps the ratio defined for empty windows.
                ctrlvalue = 1. + ctrlRDS.getCounts(chrom=chrom, rmin=start, rmax=stop, uniqs=True, multi=True, splices=False, reportCombined=True)
                ctrlvalue = ctrlvalue / ctrlnormalizeBy / length
                fields.append(_formatValue(ctrlvalue))
                fields.append(_formatValue(value / ctrlvalue))

            outfile.write('\t'.join(fields) + '\n')
    finally:
        # Close the output file even when reading or counting fails.
        outfile.close()

# Script entry point. NOTE: there is no `if __name__ == '__main__'` guard, so
# importing this module also executes run() and parses sys.argv.
run()
