##################################
#                                #
# Last modified 01/21/2010       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from cistematic.core import Genome
from cistematic.core.geneinfo import geneinfoDB
from commoncode import *

################################################
#
#     NOTE: the -stalledPolymerase analysis is unfinished; its options
#     are accepted on the command line but not yet used.
#
################################################
def _tssWindow(tss, isForward, upstreamStart, downstreamStop):
    """Return the (rmin, rmax) window around `tss`, oriented by strand.

    On the forward strand "upstream" means lower coordinates; on the reverse
    strand it means higher coordinates, so the two flanks swap sides.
    """
    if isForward:
        return tss - upstreamStart, tss + downstreamStop
    return tss - downstreamStop, tss + upstreamStart


def _splitFields(line):
    """Split one tab-delimited line into stripped fields.

    Only the line terminator is removed before splitting: calling strip() on
    the whole line would also drop an empty leading field and shift every
    column one position to the left.
    """
    return [field.strip() for field in line.rstrip('\r\n').split('\t')]


def _rpkm(count, normalizeBy, regionLength):
    """Return reads per kilobase per million mapped reads.

    normalizeBy is the dataset size in millions of reads. Returns NaN instead
    of raising ZeroDivisionError for a zero-length region or an empty dataset.
    """
    denominator = normalizeBy * (regionLength / 1000.0)
    if denominator == 0:
        return float('nan')
    return count / denominator


def _loadGenomeAnnotation(genome, upstreamStart, downstreamStop):
    """Build TSS-centred regions from the cistematic gene models of `genome`.

    Genes whose span (smallest feature start to largest feature stop) is not
    longer than both upstreamStart and downstreamStop are skipped. For 'F' and
    'R' genes the region is the strand-aware window around the TSS; for any
    other orientation value the whole gene span is kept as the region.
    Each entry's 'outline' starts with the gene ID followed by the name.
    """
    genes = {}
    hg = Genome(genome)
    idb = geneinfoDB()
    featDict = hg.getallGeneFeatures()
    total = len(featDict)
    for index, geneID in enumerate(featDict.keys()):
        if index % 1000 == 0:
            # progress: number of genes still to process
            print(total - index)
        geneInfo = idb.getGeneInfo((genome, geneID))
        if geneInfo == []:
            name = 'LOC' + str(geneID)
        else:
            name = geneInfo[0]
        features = featDict[geneID]
        leftPos = [int(feature[2]) for feature in features]
        rightPos = [int(feature[3]) for feature in features]
        chrom = 'chr' + str(features[0][1])
        orientation = str(features[0][4])
        geneStart = min(leftPos)
        geneStop = max(rightPos)
        # Too short to hold a full upstream/downstream window: skip entirely
        # (do not leave an empty placeholder entry behind).
        if geneStop - geneStart <= upstreamStart or geneStop - geneStart <= downstreamStop:
            continue
        gene = {'name': name, 'orientation': orientation, 'chr': chrom}
        rmin, rmax = geneStart, geneStop
        if orientation == 'F':
            gene['TSS'] = geneStart
            gene['GeneEnd'] = geneStop
            rmin, rmax = _tssWindow(geneStart, True, upstreamStart, downstreamStop)
        elif orientation == 'R':
            gene['TSS'] = geneStop
            gene['GeneEnd'] = geneStart
            rmin, rmax = _tssWindow(geneStop, False, upstreamStart, downstreamStop)
        gene['rmin'] = rmin
        gene['rmax'] = rmax
        gene['outline'] = '\t'.join([str(geneID), name, chrom, str(rmin), str(rmax)])
        genes[name] = gene
    return genes


def _loadGENCODEAnnotation(filename, upstreamStart, downstreamStop):
    """Build TSS-centred regions from a custom GENCODE-derived annotation file.

    Each non-blank line holds tab-delimited name, chrom, start, stop and
    strand ('+' or '-'). The TSS is start on '+' and stop on '-'. Lines with
    any other strand value are reported and skipped.
    """
    genes = {}
    with open(filename) as infile:
        for line in infile:
            if not line.strip():
                continue
            fields = _splitFields(line)
            name = fields[0]
            chrom = fields[1]
            orientation = fields[4]
            start = int(fields[2])
            stop = int(fields[3])
            if orientation == '+':
                tss, geneEnd = start, stop
            elif orientation == '-':
                tss, geneEnd = stop, start
            else:
                print('skipping GENCODE line with unknown strand: ' + line.rstrip('\r\n'))
                continue
            rmin, rmax = _tssWindow(tss, orientation == '+', upstreamStart, downstreamStop)
            genes[name] = {
                'name': name,
                'orientation': orientation,
                'chr': chrom,
                'TSS': tss,
                'GeneEnd': geneEnd,
                'rmin': rmin,
                'rmax': rmax,
                'outline': '\t'.join([name, chrom, str(rmin), str(rmax)]),
            }
    return genes


def _loadUCSCAnnotation(filename, doRNA, RNAlength):
    """Read TSS regions from a UCSC TSS-radius bed-like file.

    Columns are name, strand, chrom, start, stop. An empty name or the
    literal 'nonamematch' becomes 'nonamematch<N>'; repeated names receive an
    '_altTSS_<k>' suffix. With doRNA, an RNA window of RNAlength bp is added,
    starting at the region midpoint and extending downstream by strand. Lines
    whose strand is neither '+' nor '-' are then reported and skipped.
    """
    genes = {}
    seen = {}
    nonamecounter = 1
    with open(filename) as infile:
        for line in infile:
            if not line.strip():
                continue
            fields = _splitFields(line)
            strand = fields[1]
            if doRNA and strand not in ('+', '-'):
                print('skipping UCSC TSS line with unknown strand: ' + line.rstrip('\r\n'))
                continue
            name = fields[0]
            if name == '' or name == 'nonamematch':
                name = 'nonamematch' + str(nonamecounter)
                nonamecounter += 1
            if name in seen:
                seen[name] += 1
                name = name + '_altTSS_' + str(seen[name])
            else:
                seen[name] = 1
            chrom = fields[2]
            start = int(fields[3])
            stop = int(fields[4])
            gene = {'name': name, 'chr': chrom, 'rmin': start, 'rmax': stop}
            columns = [name, chrom, str(start), str(stop)]
            if doRNA:
                midpoint = int(start + ((stop - start) / 2.0))
                if strand == '+':
                    RNArmin, RNArmax = midpoint, midpoint + RNAlength
                else:
                    RNArmin, RNArmax = midpoint - RNAlength, midpoint
                gene['RNArmin'] = RNArmin
                gene['RNArmax'] = RNArmax
                columns += [strand, str(RNArmin), str(RNArmax)]
            gene['outline'] = '\t'.join(columns)
            genes[name] = gene
    return genes


def _appendRPKMColumn(genes, rdsfilename, minKey, maxKey, splices, doMulti, doAddRead, cachePages):
    """Append one RPKM value from `rdsfilename` to every gene's 'outline'.

    Reads are counted over genes[name][minKey]..genes[name][maxKey]; genes are
    visited in sorted-name order.
    """
    hitRDS = readDataset(rdsfilename, verbose=True, cache=True)
    # sqlite default_cache_size is 2000 pages
    if cachePages > hitRDS.getDefaultCacheSize():
        hitRDS.setDBcache(cachePages)
    normalizeBy = len(hitRDS) / 1000000.
    for count, name in enumerate(sorted(genes), 1):
        if count % 1000 == 0:
            print(count)
        gene = genes[name]
        rmin = gene[minKey]
        rmax = gene[maxKey]
        value = hitRDS.getCounts(chrom=gene['chr'], rmin=rmin, rmax=rmax, uniqs=True, multi=doMulti, splices=splices, reportCombined=True)
        if doAddRead:
            value = value + 1
        gene['outline'] = gene['outline'] + '\t' + str(_rpkm(value, normalizeBy, rmax - rmin))


def run():
    """Command-line entry point: write a table of per-region RPKMs.

    usage: python <script> genome list-of-files outputfilename [options]

    Regions come from -GENCODE if given, else -UCSCTSS, else the cistematic
    gene models of `genome`. For every entry of list-of-files containing
    '.rds' (but not '.rds.log'), one RPKM column is added. With -RNA (which
    requires -UCSCTSS and no -GENCODE), a final column of spliced RNA RPKMs
    over each region's RNA window is added. The header and rows are
    tab-delimited with matching column counts; rows are sorted by name.
    """
    if len(sys.argv) < 4:
        print('usage: python %s genome list-of-files outputfilename [-RNA RNArdsfilename regionlength] [-UCSCTSS UCSCTSSradiusbedfile] [-GENCODE GENCODE-annotation-custom-file] [-addreadwhenzero] [-nomulti] [-upstreamStart bp (default 2000)] [-downstreamStop bp (default 2000)] [-stalledPolymerase PolIIrds Ctrlrds minTSSRPKM minTSSEnrichment TSS-vs-GeneBody-enrichment-ratio TSSwindow GeneBodyStart] [-cache size] ' % sys.argv[0])
        sys.exit(1)

    genome = sys.argv[1]
    inputfilelist = sys.argv[2]
    outfilename = sys.argv[3]

    doGENCODE = '-GENCODE' in sys.argv
    if doGENCODE:
        GENCODEfile = sys.argv[sys.argv.index('-GENCODE') + 1]
        print('will use GENCODE annotation')

    doRNA = '-RNA' in sys.argv
    RNAlength = 0
    if doRNA:
        RNAindex = sys.argv.index('-RNA')
        RNASeqrdsfilename = sys.argv[RNAindex + 1]
        RNAlength = int(sys.argv[RNAindex + 2])
        print('will include RNA values within ' + str(RNAlength) + ' bp from TSS')

    doUCSC = '-UCSCTSS' in sys.argv
    if doUCSC:
        UCSCTSSradiusbedfile = sys.argv[sys.argv.index('-UCSCTSS') + 1]
        print('will use supplied UCSC TSS bed file ')

    # RNA windows are only produced by the UCSC loader, and -GENCODE takes
    # precedence over -UCSCTSS; fail now rather than after all counting.
    if doRNA and (doGENCODE or not doUCSC):
        print('-RNA requires -UCSCTSS and cannot be combined with -GENCODE')
        sys.exit(1)

    # -stalledPolymerase is still unfinished: its values are parsed (so a
    # malformed command line fails early) but not used anywhere yet.
    if '-stalledPolymerase' in sys.argv:
        stalledIndex = sys.argv.index('-stalledPolymerase')
        PolIIrdsfilename = sys.argv[stalledIndex + 1]
        Ctrlrdsfilename = sys.argv[stalledIndex + 2]
        minTSSRPKM = float(sys.argv[stalledIndex + 3])
        minTSSEnrichment = float(sys.argv[stalledIndex + 4])
        TSS_vs_GeneBody_enrichment_ratio = float(sys.argv[stalledIndex + 5])
        TSSwindow = int(sys.argv[stalledIndex + 6])
        GeneBodyStart = float(sys.argv[stalledIndex + 7])

    doMulti = '-nomulti' not in sys.argv
    doAddRead = '-addreadwhenzero' in sys.argv

    cachePages = -1
    if '-cache' in sys.argv:
        cachePages = int(sys.argv[sys.argv.index('-cache') + 1])

    upstreamStart = 2000
    if '-upstreamStart' in sys.argv:
        upstreamStart = int(sys.argv[sys.argv.index('-upstreamStart') + 1])
    downstreamStop = 2000
    if '-downstreamStop' in sys.argv:
        downstreamStop = int(sys.argv[sys.argv.index('-downstreamStop') + 1])

    # Opened before the (potentially long) counting so an unwritable path fails fast.
    with open(outfilename, 'w') as outfile:
        if doGENCODE:
            genes = _loadGENCODEAnnotation(GENCODEfile, upstreamStart, downstreamStop)
            header = ['#GeneName', 'chr', 'start', 'stop']
        elif doUCSC:
            genes = _loadUCSCAnnotation(UCSCTSSradiusbedfile, doRNA, RNAlength)
            header = ['#GeneName', 'chr', 'start', 'stop']
            if doRNA:
                header += ['orientation', 'RNAstart', 'RNAstop']
        else:
            # Only load the genome annotation when nothing replaces it.
            genes = _loadGenomeAnnotation(genome, upstreamStart, downstreamStop)
            header = ['#GeneID', 'GeneName', 'chr', 'start', 'stop']

        with open(inputfilelist) as listfile:
            files = sorted(line.strip() for line in listfile)
        print(files)
        for hitfilename in files:
            if '.rds' not in hitfilename or '.rds.log' in hitfilename:
                continue
            print(hitfilename)
            header.append(hitfilename.split('.rds')[0])
            _appendRPKMColumn(genes, hitfilename, 'rmin', 'rmax', False, doMulti, doAddRead, cachePages)

        if doRNA:
            print(RNASeqrdsfilename)
            header.append(RNASeqrdsfilename.split('/')[-1].split('.rds')[0])
            _appendRPKMColumn(genes, RNASeqrdsfilename, 'RNArmin', 'RNArmax', True, doMulti, doAddRead, cachePages)

        outfile.write('\t'.join(header) + '\n')
        for name in sorted(genes):
            outfile.write(genes[name]['outline'] + '\n')
   
# Run only when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    run()
