##################################
#                                #
# Last modified 01/14/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from cistematic.core import Genome
from cistematic.core.geneinfo import geneinfoDB
from commoncode import *

def _optionValue(argv, flag):
    """Return the command-line argument that follows *flag* in *argv*.

    Prints a message and exits with status 1 when *flag* is the last
    argument, instead of failing with an IndexError.
    """
    position = argv.index(flag) + 1
    if position >= len(argv):
        print('%s requires a value' % flag)
        sys.exit(1)
    return argv[position]


def _readGtfExons(gtf):
    """Collect the exon records of the GTF file named *gtf*.

    Returns a dict geneID -> transcriptID -> list of
    (chromosome, left, right, strand) tuples, one per 'exon' line, in file
    order.  Comment lines and lines with fewer than nine tab-separated
    fields are ignored.
    """
    GeneDict = {}
    linelist = open(gtf)
    try:
        for line in linelist:
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            if len(fields) < 9 or fields[2] != 'exon':
                continue
            geneID = fields[8].split('gene_id "')[1].split('";')[0]
            transcript = fields[8].split('transcript_id "')[1].split('";')[0]
            exon = (fields[0], int(fields[3]), int(fields[4]), fields[6])
            GeneDict.setdefault(geneID, {}).setdefault(transcript, []).append(exon)
    finally:
        linelist.close()
    return GeneDict


def _endRPKMs(hitRDS, chrom, left, right, window, normalizeBy, doMulti):
    """Return (leftRPKM, rightRPKM) for the two terminal windows of a model.

    The left window is [left, left + window] and the right window is
    [right - window, right]; each count returned by hitRDS.getCounts is
    divided by normalizeBy * (window / 1000.0).
    """
    scale = normalizeBy * (window / 1000.0)
    leftCount = hitRDS.getCounts(chrom=chrom, rmin=left, rmax=left + window, uniqs=True, multi=doMulti, splices=True, reportCombined=True)
    rightCount = hitRDS.getCounts(chrom=chrom, rmin=right - window, rmax=right, uniqs=True, multi=doMulti, splices=True, reportCombined=True)
    return leftCount / scale, rightCount / scale


def _orientEnds(orientation, leftRPKM, rightRPKM):
    """Return (RPKM5, RPKM3) for the given strand, or None if it is unknown.

    'F' and '+' map the left window to the 5' end; 'R' and '-' map the left
    window to the 3' end.
    """
    if orientation in ('F', '+'):
        return leftRPKM, rightRPKM
    if orientation in ('R', '-'):
        return rightRPKM, leftRPKM
    return None


def _signedRatio(RPKM5, RPKM3):
    """Return the 5'/3' ratio, with a 0.0001 pseudocount on both values.

    A ratio below 1 is reported as the negative of its reciprocal.
    """
    ratio = (RPKM5 + 0.0001) / (RPKM3 + 0.0001)
    if ratio < 1:
        ratio = (-1) * (1.0 / ratio)
    return ratio


def _reportGtfTranscripts(gtf, hitRDS, window, normalizeBy, doMulti, doUnique, outfile):
    """Write one 5'/3' line per transcript of the GTF file *gtf* to *outfile*.

    With doUnique set, genes having more than one transcript are skipped.
    """
    GeneDict = _readGtfExons(gtf)

    outfile.write('GeneID\tTranscriptID\tchr\tstart\tstop\torientation\t5RPKM\t3RPKM\tratio' + '\n')

    processed = 0
    for ID in sorted(GeneDict):
        processed += 1
        if processed % 1000 == 0:
            print('%d genes processed' % processed)
        if doUnique and len(GeneDict[ID]) > 1:
            continue
        for transcript in sorted(GeneDict[ID]):
            exons = GeneDict[ID][transcript]
            chrom = exons[0][0]
            orientation = exons[0][3]
            # Exons are not guaranteed to be listed in ascending coordinate
            # order, so take the extremes over all exons rather than the
            # first/last record.
            left = min([exon[1] for exon in exons])
            right = max([exon[2] for exon in exons])
            leftRPKM, rightRPKM = _endRPKMs(hitRDS, chrom, left, right, window, normalizeBy, doMulti)
            ends = _orientEnds(orientation, leftRPKM, rightRPKM)
            if ends is None:
                # Without a strand the 5' and 3' ends cannot be told apart.
                print('unknown orientation %s for %s, skipping' % (orientation, transcript))
                continue
            RPKM5, RPKM3 = ends
            ratio = _signedRatio(RPKM5, RPKM3)
            columns = [ID, transcript, chrom, str(left), str(right), orientation, str(RPKM5), str(RPKM3), str(ratio)]
            outfile.write('\t'.join(columns) + '\n')


def _reportKnownGeneTranscripts(genome, knownGene, knownToLocusLink, hitRDS, window, normalizeBy, doMulti, doUnique, outfile):
    """Write one 5'/3' line per usable model of the knownGene file to *outfile*.

    Models are kept only if knownToLocusLink maps them to a gene ID that is
    present in the Genome feature dictionary.  With doUnique set, genes
    mapped to more than one model are skipped.
    """
    geneIDDict = {}
    hg = Genome(genome)
    idb = geneinfoDB()
    # NOTE(review): the result of getallGeneInfo is never read; the call is
    # kept in case it has side effects -- confirm and drop if it has none.
    geneinfoDict = idb.getallGeneInfo(genome)
    featDict = hg.getallGeneFeatures()
    total = len(featDict.keys())
    seen = 0
    for k in featDict.keys():
        if seen % 1000 == 0:
            # countdown of genes still to be named
            print(total - seen)
        seen += 1
        info = idb.getGeneInfo((genome, k))
        if info == []:
            name = 'LOC' + str(k)
        else:
            name = info[0]
        geneIDDict[k] = name

    transcriptDict = {}
    transcriptToIDDict = {}
    inputdatafile = open(knownToLocusLink)
    try:
        for line in inputdatafile:
            fields = line.strip().split('\t')
            if len(fields) < 2:
                continue
            UCSCID = fields[0]
            ID = fields[1]
            try:
                name = geneIDDict[ID]
            except KeyError:
                print('problem with %s %s skipping' % (UCSCID, ID))
                continue
            if ID not in transcriptDict:
                transcriptDict[ID] = {'models': []}
            transcriptDict[ID]['models'].append(UCSCID)
            transcriptDict[ID]['name'] = name
            transcriptToIDDict[UCSCID] = ID
    finally:
        inputdatafile.close()

    outfile.write('UCSCID\tname\tchr\tstart\tstop\torientation\tmRNAlength\t5RPKM\t3RPKM\tratio' + '\n')

    inputdatafile = open(knownGene)
    try:
        for line in inputdatafile:
            fields = line.strip().split('\t')
            if len(fields) < 10:
                continue
            UCSCID = fields[0]
            if UCSCID not in transcriptToIDDict:
                continue
            ID = transcriptToIDDict[UCSCID]
            if doUnique and len(transcriptDict[ID]['models']) > 1:
                continue
            chrom = fields[1]
            orientation = fields[2]
            # The exon lists may end with a trailing comma; dropping empty
            # pieces handles both the terminated and unterminated forms.
            exonLefts = [int(x) for x in fields[8].split(',') if x]
            exonRights = [int(x) for x in fields[9].split(',') if x]
            if not exonLefts or not exonRights:
                continue
            length = sum([r - l for l, r in zip(exonLefts, exonRights)])
            leftRPKM, rightRPKM = _endRPKMs(hitRDS, chrom, exonLefts[0], exonRights[-1], window, normalizeBy, doMulti)
            ends = _orientEnds(orientation, leftRPKM, rightRPKM)
            if ends is None:
                # Without a strand the 5' and 3' ends cannot be told apart.
                print('unknown orientation %s for %s, skipping' % (orientation, UCSCID))
                continue
            RPKM5, RPKM3 = ends
            name = transcriptDict[ID]['name']
            ratio = (RPKM5 + 0.0001) / (RPKM3 + 0.0001)
            # progress line shows the raw ratio, before the sign convention
            print('%s %s %s %s %s %s %s %s %s %s' % (UCSCID, name, chrom, fields[3], fields[4], orientation, length, RPKM5, RPKM3, ratio))
            if ratio < 1:
                ratio = (-1) * (1.0 / ratio)
            # The trailing tab is kept so existing rows keep their layout.
            outline = '%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t' % (UCSCID, name, chrom, fields[3], fields[4], orientation, length, RPKM5, RPKM3, ratio)
            outfile.write(outline + '\n')
    finally:
        inputdatafile.close()


def run():
    """Report read density at the 5' and 3' ends of gene models.

    Reads its arguments from sys.argv:
        genome knownGene knownToLocusLink rdsfilename windowlength
        outputfilename [-uniqueModels] [-cache size] [-nomulti] [-gtf filename]

    For every model, reads are counted in a window of windowlength bases at
    each end, converted to RPKM, and written to outputfilename together with
    their ratio.  Models come from the GTF file when -gtf is given, otherwise
    from the knownGene / knownToLocusLink files.

    Exits with status 1 when arguments are missing or invalid.
    """
    # six positional arguments follow the script name
    if len(sys.argv) < 7:
        print('usage: python %s genome knownGene knownToLocusLink rdsfilename windowlength outputfilename [-uniqueModels] [-cache size] [-nomulti] [-gtf filename]' % sys.argv[0])
        sys.exit(1)

    genome = sys.argv[1]
    knownGene = sys.argv[2]
    knownToLocusLink = sys.argv[3]
    hitfilename = sys.argv[4]
    window = int(sys.argv[5])
    outfilename = sys.argv[6]

    if window <= 0:
        # the window length is used as a divisor when computing RPKM
        print('windowlength must be a positive integer')
        sys.exit(1)

    doMulti = '-nomulti' not in sys.argv

    doGTF = '-gtf' in sys.argv
    if doGTF:
        gtf = _optionValue(sys.argv, '-gtf')

    cachePages = -1
    if '-cache' in sys.argv:
        cachePages = int(_optionValue(sys.argv, '-cache'))

    doUnique = '-uniqueModels' in sys.argv
    if doUnique:
        print('will only use genes for which a single gene models has been annotated')

    outfile = open(outfilename, 'w')
    try:
        # NOTE(review): cache=True is passed whether or not -cache was given;
        # the flag only influences the page count below -- confirm intent.
        hitRDS = readDataset(hitfilename, verbose = True, cache=True)
        #sqlite default_cache_size is 2000 pages
        if cachePages > hitRDS.getDefaultCacheSize():
            hitRDS.setDBcache(cachePages)
        # NOTE(review): metadata/dataType are never read afterwards; kept
        # because the lookup fails early on a dataset without 'dataType'.
        metadata = hitRDS.getMetadata()
        dataType = metadata['dataType']
        # dataset size in millions, the per-million factor of RPKM
        normalizeBy = len(hitRDS)/1000000.

        if doGTF:
            _reportGtfTranscripts(gtf, hitRDS, window, normalizeBy, doMulti, doUnique, outfile)
        else:
            _reportKnownGeneTranscripts(genome, knownGene, knownToLocusLink, hitRDS, window, normalizeBy, doMulti, doUnique, outfile)
    finally:
        # close the report on every path, including errors
        outfile.close()
   
# Run the report only when executed as a script, so that importing this
# module does not parse sys.argv or start processing.
if __name__ == '__main__':
    run()
