##################################
#                                #
# Last modified 9/28/2009        # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from cistematic.core import Genome
from cistematic.core.geneinfo import geneinfoDB
from commoncode import *

def _flagValue(argv, flag, convert=str):
    """Return convert(<argument following flag>), or None if flag is absent.

    Prints an error and exits with status 1 when flag is the last argument
    and therefore has no value.
    """
    if flag not in argv:
        return None
    valueIndex = argv.index(flag) + 1
    if valueIndex >= len(argv):
        print('missing value for option %s' % flag)
        sys.exit(1)
    return convert(argv[valueIndex])


def _geneRegion(leftPos, rightPos, orientation, upstreamStart=None, downstreamStop=None):
    """Return the (rmin, rmax) window to count for one gene, or None to skip it.

    The default window spans min(leftPos)..max(rightPos). When upstreamStart
    is given, genes whose span is <= |upstreamStart| + 500 are skipped (None).
    Otherwise the window start is shifted by upstreamStart bp: for 'F' rmin is
    decreased, and for 'R' rmax is increased. If downstreamStop is also given,
    the window is capped at downstreamStop bp from that shifted start, without
    going past the gene's far end. Orientations other than 'F'/'R' keep the
    default window. downstreamStop has no effect without upstreamStart.
    """
    rmin = min(leftPos)
    rmax = max(rightPos)
    if upstreamStart is None:
        return rmin, rmax
    if rmax - rmin <= math.fabs(upstreamStart) + 500:
        return None
    if orientation == 'F':
        rmin -= upstreamStart
        if downstreamStop is not None:
            # Uses the shifted rmin; never extend past the gene's rightmost end.
            rmax = min(rmax, rmin + downstreamStop)
    elif orientation == 'R':
        rmax += upstreamStart
        if downstreamStop is not None:
            # Uses the shifted rmax; never extend past the gene's leftmost end.
            rmin = max(rmin, rmax - downstreamStop)
    return rmin, rmax


def run(argv=None):
    """Write a per-gene RPKM table for an RDS dataset.

    Command line (argv defaults to sys.argv):
        genome rdsfilename outputfilename [-upstreamStart bp] [-downstreamStop bp]
        [-control controlrdsfilename] [-cache size]

    For every gene key in Genome(genome).getallGeneFeatures(), the window
    chosen by _geneRegion is counted with getCounts(uniqs=True, multi=True,
    splices=False, reportCombined=True). Both datasets are normalized by
    len(dataset) / 1e6 (presumably reads in millions; confirm against
    readDataset).

    Without -control:
        RPKM = count / (norm * (rmax - rmin) / 1000)
    With -control, the value is not length-normalized:
        RPKM = ((count + 1) / norm) / ((ctrlCount + 1) / ctrlNorm)

    One tab-separated line is written per gene, under the header
    GeneID, GeneName, Chr, Start, End, Orientation, RPKM.
    Exits with status 1 on a malformed command line.
    """
    if argv is None:
        argv = sys.argv

    # argv[1], argv[2] and argv[3] are all required positional arguments.
    if len(argv) < 4:
        print('usage: python %s genome rdsfilename outputfilename [-upstreamStart bp] [-downstreamStop bp] [-control controlrdsfilename] [-cache size] ' % argv[0])
        print("\tNote: if you use the -control option, you will get the ratio of control over ChIP over the entire region, which is not normalized for gene length\n")
        print("\tNote: if you use the -upstreamStart option, you will only get genes longer than the specifed basepairs + 500bp\n")
        sys.exit(1)

    genome = argv[1]
    hitfilename = argv[2]
    outfilename = argv[3]

    # Parse every option before opening anything, so that a malformed command
    # line fails without leaving an empty output file behind.
    cachePages = _flagValue(argv, '-cache', int)
    doCache = cachePages is not None
    upstreamStart = _flagValue(argv, '-upstreamStart', int)
    downstreamStop = _flagValue(argv, '-downstreamStop', int)
    controlrdsfilename = _flagValue(argv, '-control')
    if upstreamStart is not None:
        print('upstreamStart %s' % upstreamStart)
    if downstreamStop is not None:
        print('downstreamStop %s' % downstreamStop)

    # NOTE(review): the main dataset is always opened with cache=True. The
    # -cache option only sets the control dataset's cache flag and the
    # SQLite page count below.
    hitRDS = readDataset(hitfilename, verbose=True, cache=True)

    hitctrlRDS = None
    if controlrdsfilename is not None:
        hitctrlRDS = readDataset(controlrdsfilename, verbose=True, cache=doCache)
        ctrlnormalizeBy = len(hitctrlRDS) / 1000000.

    # sqlite default_cache_size is 2000 pages
    if doCache and cachePages > hitRDS.getDefaultCacheSize():
        hitRDS.setDBcache(cachePages)

    normalizeBy = len(hitRDS) / 1000000.

    hg = Genome(genome)
    idb = geneinfoDB()
    featDict = hg.getallGeneFeatures()
    geneCount = len(featDict)

    outfile = open(outfilename, 'w')
    try:
        outfile.write('GeneID\tGeneName\tChr\tStart\tEnd\tOrientation\tRPKM\n')
        for i, k in enumerate(featDict):
            if i % 1000 == 0:
                # Progress indicator: number of genes still to process.
                print(geneCount - i)

            features = featDict[k]
            if not features:
                continue

            # Query the gene-info DB once; fall back to a LOC<id> name.
            geneInfo = idb.getGeneInfo((genome, k))
            if geneInfo:
                name = geneInfo[0]
            else:
                name = 'LOC' + str(k)

            # Feature fields as used here: [1] chromosome, [2] left position,
            # [3] right position, [4] orientation ('F'/'R').
            leftPos = [int(feature[2]) for feature in features]
            rightPos = [int(feature[3]) for feature in features]
            chrom = 'chr' + str(features[0][1])
            orientation = str(features[0][4])

            region = _geneRegion(leftPos, rightPos, orientation, upstreamStart, downstreamStop)
            if region is None:
                continue
            rmin, rmax = region

            value = hitRDS.getCounts(chrom=chrom, rmin=rmin, rmax=rmax, uniqs=True, multi=True, splices=False, reportCombined=True)
            if hitctrlRDS is not None:
                ctrlValue = hitctrlRDS.getCounts(chrom=chrom, rmin=rmin, rmax=rmax, uniqs=True, multi=True, splices=False, reportCombined=True)
                # +1 pseudocounts keep the ratio defined when a window is empty.
                RPKM = ((1. + value) / normalizeBy) / ((1. + ctrlValue) / ctrlnormalizeBy)
            else:
                RPKM = value / (normalizeBy * ((rmax - rmin) / 1000.0))

            # The trailing tab before the newline is kept for output compatibility.
            outfile.write('%s\t%s\t%s\t%s\t%s\t%s\t%s\t\n' % (k, name, chrom, rmin, rmax, orientation, RPKM))
    finally:
        outfile.close()
   
# Run only when executed as a script, not when the module is imported.
if __name__ == '__main__':
    run()
