##################################
#                                #
# Last modified 6/17/2009         # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from cistematic.core import Genome
from cistematic.core.geneinfo import geneinfoDB

# Best-effort: switch on psyco when it can be imported and started; when it
# cannot, report that and carry on without it.
try:
    import psyco
    psyco.full()
except Exception:
    # Exception (rather than a bare except) so that KeyboardInterrupt and
    # SystemExit are not swallowed here.
    print('psyco not running')

from commoncode import *

def _binWindows(anchor, binSize, binCount, forward):
    """Return the integer (rmin, rmax) count windows for one profile region.

    Windows are binSize wide and are laid out in transcript order starting
    at anchor: toward higher coordinates when forward is true, toward lower
    coordinates otherwise.

    NOTE(review): only binCount - 1 windows are produced, exactly as the
    original loops did, so the last bin of every region is never counted.
    Confirm whether that is intended before changing it; doing so alters the
    number of output columns.
    """
    windows = []
    current = anchor
    for _ in range(binCount - 1):
        if forward:
            windows.append((int(current), int(current + binSize)))
            current = current + binSize
        else:
            windows.append((int(current - binSize), int(current)))
            current = current - binSize
    return windows


def _regionValues(hitRDS, ctrlRDS, chrom, windows, binSize, normalizeBy,
                  ctrlNormalizeBy, withMulti):
    """Return one value per window of a region.

    Without a control dataset (ctrlRDS is None) each value is the hit count
    divided by (normalizeBy * binSize / 1000).  With a control dataset each
    value is the ratio of the two normalized counts, with 1 added to both
    counts so the ratio is always defined.
    """
    values = []
    for (rmin, rmax) in windows:
        hits = hitRDS.getCounts(chrom=chrom, rmin=rmin, rmax=rmax, uniqs=True,
                                multi=withMulti, splices=False,
                                reportCombined=True)
        if ctrlRDS is None:
            values.append(hits / (normalizeBy * binSize / 1000))
        else:
            ctrlHits = ctrlRDS.getCounts(chrom=chrom, rmin=rmin, rmax=rmax,
                                         uniqs=True, multi=withMulti,
                                         splices=False, reportCombined=True)
            values.append(((1. + hits) / normalizeBy) /
                          ((1. + ctrlHits) / ctrlNormalizeBy))
    return values


def _geneRegions(gene, layout):
    """Return (forward, regions) describing where one gene is binned.

    layout is the tuple (TSSupstreambp, TSSupstreambins, TSSdownstreambp,
    TSSdownstreambins, genebodybins, UTRdownstreambp, UTRdownstreambins).
    regions is a list of (anchor, binSize, binCount) tuples in transcript
    order: upstream of the start, downstream of the start, the rest of the
    gene body, and downstream of the end.  For an orientation other than
    'F' or 'R' the list is empty, so the gene is written without values.
    """
    orientation = gene['orientation']
    if orientation not in ('F', 'R'):
        return True, []

    (upBp, upBins, downBp, downBins, bodyBins, utrBp, utrBins) = layout
    left = gene['leftPos']
    right = gene['rightPos']
    upSize = float(upBp) / upBins
    downSize = float(downBp) / downBins
    utrSize = float(utrBp) / utrBins

    if orientation == 'F':
        bodyStart = left + downBp
        bodySize = float(right - bodyStart) / bodyBins
        regions = [(left - upBp, upSize, upBins),
                   (left, downSize, downBins),
                   (bodyStart, bodySize, bodyBins),
                   (right, utrSize, utrBins)]
        return True, regions

    # Reverse strand: the transcript starts at rightPos and every region is
    # walked toward lower coordinates.
    bodyStart = right - downBp
    bodySize = float(bodyStart - left) / bodyBins
    regions = [(right + upBp, upSize, upBins),
               (right, downSize, downBins),
               (bodyStart, bodySize, bodyBins),
               (left, utrSize, utrBins)]
    return False, regions


def _loadGenes(genome, minLength):
    """Return a dict of gene records keyed by gene name.

    Each record holds geneID, name, chromosome, leftPos, rightPos and
    orientation.  The span is the minimum left / maximum right coordinate
    over the gene's features; genes spanning less than minLength are
    reported and skipped.  Features are indexed as the original code did:
    [1] chromosome, [2] left, [3] right, [4] orientation.
    """
    hg = Genome(genome)
    idb = geneinfoDB()
    featDict = hg.getallGeneFeatures()
    total = len(featDict)
    genes = {}
    for index, geneID in enumerate(featDict):
        if index % 1000 == 0:
            print(total - index)
        features = featDict[geneID]
        if not features:
            # Nothing to measure; min()/max() below would fail on no features.
            continue
        info = idb.getGeneInfo((genome, geneID))
        if not info:
            name = 'LOC' + str(geneID)
        else:
            name = info[0]
        start = min([int(feature[2]) for feature in features])
        stop = max([int(feature[3]) for feature in features])
        if stop - start < minLength:
            print('gene %s shorter than %s not considered' % (name, minLength))
            continue
        genes[name] = {'geneID': geneID,
                       'name': name,
                       'chromosome': 'chr' + str(features[0][1]),
                       'leftPos': start,
                       'rightPos': stop,
                       'orientation': str(features[0][4])}
    return genes


def run():
    """Write a binned read-density profile around every gene of a genome.

    Reads its arguments from sys.argv: genome, rds file and output file,
    followed by the optional flags shown in the usage message.  For each
    retained gene one tab-separated line is written: name, chromosome,
    left position, right position, orientation, then the per-bin values for
    the region upstream of the start, downstream of the start, the gene body
    and downstream of the end, in transcript order.

    Exits with status 1 after printing the usage message when fewer than
    three positional arguments are supplied.
    """
    argv = sys.argv
    # Three positional arguments are read below, so argv needs four entries.
    if len(argv) < 4:
        print('usage: python %s genome rdsfile outfilename [-TSSupstream bp bins] [-TSSdownstream bp bins] [-genebodybins bins] [-3UTR bpdownstream bins] [-control controlrdsfile] [-zeros number_to_convert_zeros_to] [-nomulti] [-cache size]' % argv[0])
        print("\tNote: Genes that are shorter than the -TSSdownstream distance plus the number of genebodybins times the length of TSSdownstream bins will not be considered\n")
        sys.exit(1)

    genome = argv[1]
    hitfilename = argv[2]
    outfilename = argv[3]

    TSSupstreambp = 2000
    TSSupstreambins = 40
    TSSdownstreambp = 1000
    TSSdownstreambins = 20
    genebodybins = 100
    UTRdownstreambp = 5000
    UTRdownstreambins = 50
    cachePages = -1
    doCache = False
    doZeros = False
    zeros = 0.0
    withMulti = True

    if '-TSSupstream' in argv:
        TSSupstreambp = int(argv[argv.index('-TSSupstream') + 1])
        TSSupstreambins = int(argv[argv.index('-TSSupstream') + 2])
    if '-TSSdownstream' in argv:
        TSSdownstreambp = int(argv[argv.index('-TSSdownstream') + 1])
        TSSdownstreambins = int(argv[argv.index('-TSSdownstream') + 2])
    if '-genebodybins' in argv:
        genebodybins = int(argv[argv.index('-genebodybins') + 1])
    if '-zeros' in argv:
        doZeros = True
        zeros = float(argv[argv.index('-zeros') + 1])
    if '-3UTR' in argv:
        UTRdownstreambp = int(argv[argv.index('-3UTR') + 1])
        UTRdownstreambins = int(argv[argv.index('-3UTR') + 2])
    if '-nomulti' in argv:
        withMulti = False
    if '-cache' in argv:
        doCache = True
        cachePages = int(argv[argv.index('-cache') + 1])

    layout = (TSSupstreambp, TSSupstreambins, TSSdownstreambp,
              TSSdownstreambins, genebodybins, UTRdownstreambp,
              UTRdownstreambins)

    outfile = open(outfilename, 'w')
    try:
        hitRDS = readDataset(hitfilename, verbose=True, cache=doCache)

        hitctrlRDS = None
        ctrlnormalizeBy = None
        if '-control' in argv:
            controlrdsfilename = argv[argv.index('-control') + 1]
            hitctrlRDS = readDataset(controlrdsfilename, verbose=True, cache=doCache)
            ctrlnormalizeBy = len(hitctrlRDS) / 1000000.

        #sqlite default_cache_size is 2000 pages
        if cachePages > hitRDS.getDefaultCacheSize():
            hitRDS.setDBcache(cachePages)

        normalizeBy = len(hitRDS) / 1000000.

        # Shortest gene that still leaves room for the TSS-downstream region
        # plus genebodybins bins of the TSS-downstream bin width.
        minLength = TSSdownstreambp + ((0.0 + TSSdownstreambp) / TSSdownstreambins) * genebodybins
        genes = _loadGenes(genome, minLength)

        print('finished parsing gene info')

        processed = 0
        for name in genes:
            processed += 1
            if processed % 1000 == 0:
                print(len(genes) - processed)
            gene = genes[name]
            if math.fabs(gene['leftPos'] - gene['rightPos']) < 1500:
                continue

            forward, regions = _geneRegions(gene, layout)
            values = []
            for (anchor, binSize, binCount) in regions:
                windows = _binWindows(anchor, binSize, binCount, forward)
                values.extend(_regionValues(hitRDS, hitctrlRDS,
                                            gene['chromosome'], windows,
                                            binSize, normalizeBy,
                                            ctrlnormalizeBy, withMulti))

            fields = [gene['name'], gene['chromosome'], str(gene['leftPos']),
                      str(gene['rightPos']), gene['orientation']]
            for value in values:
                # Zero bins are only replaced when -zeros was requested.
                if doZeros and value <= 0.0:
                    value = zeros
                fields.append(str(value))
            outfile.write('\t'.join(fields) + '\n')
    finally:
        outfile.close()

# Script entry point: process the command line in sys.argv.
# NOTE: this call is not guarded by __name__ == '__main__', so importing this
# module also executes run().
run()
