##################################
#                                #
# Last modified 12/09/2010       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import math
import random
import string
from cistematic.core.geneinfo import geneinfoDB
from cistematic.genomes import Genome
from sets import Set

def getReverseComplement(preliminarysequence):
    """Return the reverse complement of a DNA sequence.

    The input is walked from its last character to its first and each
    character is replaced by its complement.  A, C, G, T and N are handled
    in upper and lower case, and case is preserved.  Any other character
    raises KeyError.
    """
    complement = {
        'A': 'T', 'C': 'G', 'G': 'C', 'T': 'A', 'N': 'N',
        'a': 't', 'c': 'g', 'g': 'c', 't': 'a', 'n': 'n',
    }
    return ''.join([complement[base] for base in reversed(preliminarysequence)])

# psyco is an optional accelerator: use it when it is importable and
# otherwise continue without it.  Catch Exception rather than using a bare
# except so that KeyboardInterrupt and SystemExit are not swallowed.
try:
    import psyco
    psyco.full()
except Exception:
    pass

def _fetchGenomicSequence(hg, chromosome, start, length):
    """Return ``(sequence, missing)`` for ``length`` bases beginning at ``start``.

    The whole span is requested from ``hg.sequence`` in one call.  If that
    call raises, the span is re-read one base at a time and each base that
    still cannot be fetched is replaced by 'N'; ``missing`` counts those
    substitutions (0 when the single call succeeds).
    """
    try:
        return hg.sequence(chromosome, start, length), 0
    except Exception:
        pass
    bases = []
    missing = 0
    # Cover exactly [start, start + length); the previous per-base fallback
    # iterated over range(left, right - left), which is usually empty.
    for position in range(start, start + length):
        try:
            bases.append(hg.sequence(chromosome, position, 1))
        except Exception:
            bases.append('N')
            missing += 1
    return ''.join(bases), missing


def _geneNameFor(idb, genome, geneID):
    """Return the first gene-info entry for ``geneID``, or 'LOC<geneID>' if none."""
    info = idb.getGeneInfo((genome, geneID))
    if info == []:
        return 'LOC' + str(geneID)
    return info[0]


def _parseGencodeExons(gtfFileName):
    """Read the 'exon' records of a GTF file.

    Returns ``{geneName: {transcriptID: [(chrom, left, right, strand), ...]}}``.
    The gene key is the ``gene_name`` attribute when present, else ``gene_id``.
    Comment lines, lines with fewer than nine tab-separated fields and
    non-exon features are skipped.
    """
    transcriptsByGene = {}
    gtf = open(gtfFileName)
    try:
        lineCount = 0
        for line in gtf:
            lineCount += 1
            if lineCount % 100000 == 0:
                print('%d lines processed' % lineCount)
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            if len(fields) < 9 or fields[2] != 'exon':
                continue
            attributes = fields[8]
            if 'gene_name "' in attributes:
                geneName = attributes.split('gene_name "')[1].split('";')[0]
            else:
                geneName = attributes.split('gene_id "')[1].split('";')[0]
            transcriptID = attributes.split('transcript_id "')[1].split('";')[0]
            exon = (fields[0], int(fields[3]), int(fields[4]), fields[6])
            transcripts = transcriptsByGene.setdefault(geneName, {})
            transcripts.setdefault(transcriptID, []).append(exon)
    finally:
        gtf.close()
    return transcriptsByGene


def _transcriptLength(exons):
    """Return the summed exon length, treating GTF coordinates as inclusive."""
    return sum([right - left + 1 for (chrom, left, right, strand) in exons])


def _writeGencodeTranscripts(hg, gtfFileName, outfile, doPolyA, tail):
    """Write the longest isoform of every gene in a GTF file; return missed bases."""
    transcriptsByGene = _parseGencodeExons(gtfFileName)
    print('Found %d genes' % len(transcriptsByGene))
    missed = 0
    processed = 0
    for geneName in transcriptsByGene.keys():
        processed += 1
        if processed % 1000 == 0:
            print('%d gene sequences processed' % processed)

        # Pick the longest isoform; on ties the first one encountered wins.
        isoforms = list(transcriptsByGene[geneName].values())
        longestIsoform = isoforms[0]
        longest = _transcriptLength(longestIsoform)
        for exons in isoforms[1:]:
            length = _transcriptLength(exons)
            if length > longest:
                longest = length
                longestIsoform = exons

        # Chromosome and strand must come from the chosen isoform itself, not
        # from whatever GTF line happened to be parsed last.
        chrom = longestIsoform[0][0]
        strand = longestIsoform[0][3]
        if strand in ('+', 'F'):
            sense = 'plus_strand'
        elif strand in ('-', 'R'):
            sense = 'minus_strand'
        else:
            print('skipping %s: unrecognised strand %r' % (geneName, strand))
            continue

        # Assemble exons in transcript (5'->3') order: ascending coordinates
        # on the plus strand, descending on the minus strand.
        minusStrand = (sense == 'minus_strand')
        orderedExons = sorted(longestIsoform, key=lambda exon: exon[1],
                              reverse=minusStrand)
        pieces = []
        for (exonChrom, left, right, exonStrand) in orderedExons:
            # GTF coordinates are 1-based and inclusive, hence the conversion
            # to (left - 1, right - left + 1).  This is the conversion the
            # minus-strand path already used; hg.sequence presumably takes a
            # 0-based start -- TODO confirm against cistematic.
            exonSequence, missing = _fetchGenomicSequence(
                hg, exonChrom[3:], left - 1, right - left + 1)
            missed += missing
            if minusStrand:
                exonSequence = getReverseComplement(exonSequence)
            pieces.append(exonSequence)
        sequence = ''.join(pieces)

        leftEnd = min([exon[1] for exon in longestIsoform])
        rightEnd = max([exon[2] for exon in longestIsoform])
        outline = '>' + geneName + ':' + chrom + ':' + str(leftEnd) + '-' + str(rightEnd) + '-' + sense
        outfile.write(outline + '\n')
        if not doPolyA:
            # Sequence length is only reported when no tail is requested.
            print(len(sequence))
        outfile.write(sequence + tail + '\n')
    return missed


def _writeCistematicGenes(hg, genome, outfile, tail):
    """Write one merged-feature sequence per cistematic gene; return missed bases."""
    idb = geneinfoDB()
    # Result unused here; the call is kept in case it primes geneinfoDB's
    # internal state -- TODO confirm whether it can be dropped.
    idb.getallGeneInfo(genome)
    featDict = hg.getallGeneFeatures()
    totalGenes = len(featDict)
    missed = 0
    processed = 0
    for geneID in featDict.keys():
        if processed % 1000 == 0:
            print(totalGenes - processed)
        processed += 1
        name = _geneNameFor(idb, genome, geneID)
        features = featDict[geneID]
        # De-duplicate identical (start, stop) pairs and process them in
        # ascending order.
        coordinates = sorted(set([(int(feature[2]), int(feature[3]))
                                  for feature in features]))
        chromosome = str(features[0][1])
        orientation = str(features[0][4])
        if orientation == 'F':
            sense = 'plus_strand'
        elif orientation == 'R':
            sense = 'minus_strand'
        else:
            print('skipping %s: unrecognised orientation %r' % (name, orientation))
            continue
        geneStart = min([left for (left, right) in coordinates])
        geneEnd = max([right for (left, right) in coordinates])
        pieces = []
        for (left, right) in coordinates:
            piece, missing = _fetchGenomicSequence(hg, chromosome, left, right - left)
            missed += missing
            pieces.append(piece)
        sequence = ''.join(pieces)
        if sense == 'minus_strand':
            sequence = getReverseComplement(sequence)
        outline = '>' + name + ':' + 'chr' + chromosome + ':' + str(geneStart) + ':' + str(geneEnd) + '-' + sense
        outfile.write(outline + '\n')
        outfile.write(sequence + tail + '\n')
    return missed


def _writeKnownGeneTranscripts(hg, genome, knownGene, knownToLocusLink,
                               outfile, doLongest, tail):
    """Write sequences for the transcripts of a knownGene table; return missed bases.

    Transcript IDs are mapped to gene IDs through the knownToLocusLink file
    and gene IDs to names through cistematic.  With ``doLongest`` only the
    longest isoform per gene is written; otherwise every isoform is written
    as ``<name>_alt_<n>``.
    """
    idb = geneinfoDB()
    # Result unused here; see the note in _writeCistematicGenes.
    idb.getallGeneInfo(genome)
    featDict = hg.getallGeneFeatures()
    # NOTE(review): this tab-separated header is written ahead of the FASTA
    # records; it looks like a leftover from another script but is kept
    # because consumers of the existing output may expect it.
    outfile.write('GeneID\tGeneName\tChr\tStart\tEnd\tOrientation\tRPKM\n')

    totalGenes = len(featDict)
    idToName = {}
    transcriptsByName = {}
    processed = 0
    for geneID in featDict.keys():
        if processed % 1000 == 0:
            print(totalGenes - processed)
        processed += 1
        name = _geneNameFor(idb, genome, geneID)
        idToName[str(geneID)] = name
        transcriptsByName[name] = []

    ucscToGeneID = {}
    linkFile = open(knownToLocusLink)
    try:
        for line in linkFile:
            fields = line.strip().split('\t')
            if len(fields) < 2:
                continue
            ucscToGeneID[fields[0]] = fields[1]
    finally:
        linkFile.close()

    missed = 0
    lineNumber = 0
    knownGeneFile = open(knownGene)
    try:
        for line in knownGeneFile:
            lineNumber += 1
            print(lineNumber)
            fields = line.split('\t')
            if fields[0] not in ucscToGeneID:
                continue
            geneID = ucscToGeneID[fields[0]]
            name = idToName.get(geneID, 'LOC' + geneID)
            chrom = fields[1]
            sense = fields[2]
            geneStart = fields[3]
            geneStop = fields[4]
            pieces = []
            for (startText, endText) in zip(fields[8].split(','), fields[9].split(',')):
                if endText == '':
                    # The exon lists end with a trailing comma.
                    continue
                left = int(startText)
                right = int(endText)
                piece, missing = _fetchGenomicSequence(hg, chrom[3:], left, right - left)
                missed += missing
                pieces.append(piece)
            # NOTE(review): sequences are stored in genomic orientation even
            # when sense is '-'; confirm whether minus-strand transcripts
            # should be reverse-complemented in this mode.
            # setdefault keeps isoforms collected earlier for genes that are
            # absent from cistematic instead of resetting the list each time.
            transcriptsByName.setdefault(name, []).append(
                (chrom, sense, geneStart, geneStop, ''.join(pieces)))
    finally:
        knownGeneFile.close()

    print('exon bp missing from CISTEMATIC %d' % missed)
    for name in transcriptsByName.keys():
        isoforms = transcriptsByName[name]
        if doLongest:
            best = None
            bestLength = 0
            for isoform in isoforms:
                if len(isoform[4]) > bestLength:
                    bestLength = len(isoform[4])
                    best = isoform
            if best is not None:
                (chrom, sense, geneStart, geneEnd, sequence) = best
                outline = '>' + name + ':' + chrom + ':' + geneStart + ':' + geneEnd + '-' + sense
                outfile.write(outline + '\n')
                outfile.write(sequence + tail + '\n')
        else:
            number = 0
            for (chrom, sense, geneStart, geneEnd, sequence) in isoforms:
                number += 1
                isoname = name + '_alt_' + str(number)
                outline = '>' + isoname + ':' + chrom + ':' + geneStart + ':' + geneEnd + '-' + sense
                outfile.write(outline + '\n')
                outfile.write(sequence + tail + '\n')
    return missed


def run():
    """Command-line entry point: write transcript sequences to a FASTA-like file.

    Positional arguments (sys.argv[1:5]): genome, knownGene file name,
    knownToLocusLink file name, output file name.  Options:

      -polyA length   append ``length`` 'A' characters to every sequence
      -longest        knownGene mode only: keep the longest isoform per gene
      -GENCODE gtf    take the annotation from a GTF file (longest isoform)
      -cistematic     take the annotation from cistematic gene features

    -GENCODE takes precedence over -cistematic; with neither option the
    knownGene/knownToLocusLink tables are used.  Prints usage and exits with
    status 1 when fewer than four positional arguments are given.
    """
    # Four positional arguments are read below, so argv needs five entries.
    if len(sys.argv) < 5:
        print('usage: python %s genome knownGene_FileName knownToLocusLink_FileName outfilename [-polyA length] [-longest] [-GENCODE gtf] [-cistematic]' % sys.argv[0])
        print('\tfor GENCODE, only the longest isoform will be returned')
        sys.exit(1)

    genome = sys.argv[1]
    knownGene = sys.argv[2]
    knownToLocusLink = sys.argv[3]
    outputfilename = sys.argv[4]

    doLongest = '-longest' in sys.argv
    doPolyA = '-polyA' in sys.argv
    tail = ''
    if doPolyA:
        tailsize = int(sys.argv[sys.argv.index('-polyA') + 1])
        tail = 'A' * tailsize
        print('will add a polyA tail of  %d nt' % tailsize)
    doGENCODE = '-GENCODE' in sys.argv
    if doGENCODE:
        GENCODEfile = sys.argv[sys.argv.index('-GENCODE') + 1]
        print('using GENCODE annotation')
    doCISTEMATIC = '-cistematic' in sys.argv
    if doCISTEMATIC:
        print('using cistematic annotation')

    outfile = open(outputfilename, 'w')
    try:
        hg = Genome(genome)
        if doGENCODE:
            _writeGencodeTranscripts(hg, GENCODEfile, outfile, doPolyA, tail)
        elif doCISTEMATIC:
            _writeCistematicGenes(hg, genome, outfile, tail)
        else:
            _writeKnownGeneTranscripts(hg, genome, knownGene, knownToLocusLink,
                                       outfile, doLongest, tail)
    finally:
        outfile.close()

# Script entry point: run() is called unconditionally at module level, so it
# executes both when this file is run as a script and when it is imported.
run()

