##################################
#                                #
# Last modified 12/02/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import bisect
import re
import string
import sys

from cistematic.core import Genome
from cistematic.core.geneinfo import geneinfoDB

# Optional speed-up: enable psyco if it can be imported and initialised,
# otherwise run without it. Catch Exception rather than using a bare except
# so Ctrl-C / sys.exit during import are not swallowed.
try:
    import psyco
    psyco.full()
except Exception:
    pass

def getSequence(hg, chromosome, start, stop, sense):
    """Fetch an interval from `hg`, reverse-complemented for the minus strand.

    The first three characters of `chromosome` (assumed to be a 'chr'
    prefix) are dropped, then the interval is requested as
    hg.sequence(name, start, stop - start).

    Parameters:
        hg -- genome object with a sequence(chromosome, start, n) method
              (third argument presumably a length -- TODO confirm)
        chromosome -- chromosome name as written in the annotation
        start, stop -- interval bounds; stop - start bases are requested
        sense -- 'F' or '+' for the forward strand, 'R' or '-' for the
                 reverse complement

    Returns the sequence string (forward strand is returned unchanged).
    Raises ValueError for any other `sense`, and KeyError if a
    reverse-strand base is not one of A/C/G/T/N (either case).
    """
    complement = {'A': 'T', 'T': 'A', 'G': 'C', 'C': 'G', 'N': 'N',
                  'a': 't', 't': 'a', 'g': 'c', 'c': 'g', 'n': 'n'}
    # NOTE(review): the slice is unconditional, so a name without a 'chr'
    # prefix (e.g. '2L') loses real characters -- confirm all inputs use it.
    name = chromosome[3:]
    if sense in ('F', '+'):
        return hg.sequence(name, start, stop - start)
    if sense in ('R', '-'):
        raw = hg.sequence(name, start, stop - start)
        # join instead of repeated += to avoid quadratic string building
        return ''.join([complement[base] for base in reversed(raw)])
    # Previously fell through to an UnboundLocalError on `sequence`.
    raise ValueError('unknown sense %r; expected F, +, R or -' % (sense,))

# Standard genetic code over the RNA alphabet (T is replaced by U before
# translation). Stop codons are listed for completeness but are never
# looked up: an ORF ends just before its first in-frame stop codon.
_CODON_TABLE = {
    'GCU': 'A', 'GCC': 'A', 'GCA': 'A', 'GCG': 'A',
    'UUA': 'L', 'UUG': 'L', 'CUU': 'L', 'CUC': 'L', 'CUA': 'L', 'CUG': 'L',
    'CGU': 'R', 'CGC': 'R', 'CGA': 'R', 'CGG': 'R', 'AGA': 'R', 'AGG': 'R',
    'AAA': 'K', 'AAG': 'K',
    'AAU': 'N', 'AAC': 'N',
    'AUG': 'M',
    'GAU': 'D', 'GAC': 'D',
    'UUU': 'F', 'UUC': 'F',
    'UGU': 'C', 'UGC': 'C',
    'CCU': 'P', 'CCC': 'P', 'CCA': 'P', 'CCG': 'P',
    'CAA': 'Q', 'CAG': 'Q',
    'UCU': 'S', 'UCC': 'S', 'UCA': 'S', 'UCG': 'S', 'AGU': 'S', 'AGC': 'S',
    'GAA': 'E', 'GAG': 'E',
    'ACU': 'T', 'ACC': 'T', 'ACA': 'T', 'ACG': 'T',
    'GGU': 'G', 'GGC': 'G', 'GGA': 'G', 'GGG': 'G',
    'UGG': 'W',
    'CAU': 'H', 'CAC': 'H',
    'UAU': 'Y', 'UAC': 'Y',
    'AUU': 'I', 'AUC': 'I', 'AUA': 'I',
    'GUU': 'V', 'GUC': 'V', 'GUA': 'V', 'GUG': 'V',
    'UAA': 'STOP', 'UGA': 'STOP', 'UAG': 'STOP',
}

_START_CODON_RE = re.compile('AUG')
# Stop codons cannot overlap one another, so finditer reports every one.
_STOP_CODON_RE = re.compile('UGA|UAA|UAG')


def _gtf_attr(attributes, key):
    """Return the value of `key "value"` in a GTF attribute column, or None.

    The key must start the column or follow a ';', so e.g. 'gene_id' does
    not match inside a longer attribute name.
    """
    match = re.search(r'(?:^|;)\s*' + re.escape(key) + r'\s+"([^"]*)"',
                      attributes)
    if match is None:
        return None
    return match.group(1)


def _transcript_label(attributes, transcriptID, geneName, doGeneName, dooId):
    """Choose the transcript name written to the output.

    Default: transcript_name, else transcript_id.
    With -oId: the oId attribute, prefixed by '<gene_name>.' when -geneName
    is also set and the line carries a gene_name. If the line has no oId,
    the line is handled as if -oId were not given.
    With -geneName (effective, no oId): geneName + '.' + transcript_id, where
    geneName already falls back to gene_id when gene_name is absent.
    """
    if dooId:
        oId = _gtf_attr(attributes, 'oId')
        if oId is not None:
            if doGeneName:
                lineGeneName = _gtf_attr(attributes, 'gene_name')
                if lineGeneName is not None:
                    return lineGeneName + '.' + oId
            return oId
    if doGeneName:
        return geneName + '.' + transcriptID
    transcriptName = _gtf_attr(attributes, 'transcript_name')
    if transcriptName is None:
        return transcriptID
    return transcriptName


def _read_annotation(inputfilename, doGeneName, dooId):
    """Group the exon lines of a GTF file by transcript.

    Returns a dict keyed by (transcriptName, transcriptID, geneName, geneID).
    Each value holds 'chr', 'orientation' and 'type' (taken from the last
    exon line seen for that transcript) and 'exons', a list of
    (start, stop) tuples with stop made exclusive.

    Raises ValueError for an exon line lacking gene_id or transcript_id.
    """
    transcripts = {}
    processed = 0
    print('Inputting annotation')
    with open(inputfilename) as annotation:
        for lineNumber, line in enumerate(annotation, 1):
            if line.startswith('#'):
                continue
            processed += 1
            if processed % 100000 == 0:
                print('%d lines processed' % processed)
            fields = line.rstrip('\r\n').split('\t')
            # Blank or truncated lines have no feature/attribute columns.
            if len(fields) < 9 or fields[2] != 'exon':
                continue
            attributes = fields[8]
            transcriptID = _gtf_attr(attributes, 'transcript_id')
            geneID = _gtf_attr(attributes, 'gene_id')
            if transcriptID is None or geneID is None:
                raise ValueError('%s line %d: exon without gene_id or '
                                 'transcript_id' % (inputfilename, lineNumber))
            geneName = _gtf_attr(attributes, 'gene_name')
            if geneName is None:
                geneName = geneID
            transcriptName = _transcript_label(attributes, transcriptID,
                                               geneName, doGeneName, dooId)
            transcriptType = _gtf_attr(attributes, 'transcript_type')
            if transcriptType is None:
                transcriptType = 'unknown'

            key = (transcriptName, transcriptID, geneName, geneID)
            record = transcripts.setdefault(key, {'exons': []})
            record['chr'] = fields[0]
            record['orientation'] = fields[6]
            record['type'] = transcriptType
            # GTF end coordinates are inclusive; +1 makes stop exclusive so
            # that stop - start (what getSequence requests) spans the exon.
            record['exons'].append((int(fields[3]), int(fields[4]) + 1))
    return transcripts


def _transcript_sequence(hg, record):
    """Concatenate a transcript's exon sequences in transcript order.

    Exons are sorted by position and reversed for the '-' strand. Returns
    None if any exon cannot be fetched, and '' for a transcript whose
    orientation is neither '+' nor '-'.
    """
    strand = record['orientation']
    if strand not in ('+', '-'):
        return ''
    exons = sorted(record['exons'])
    if strand == '-':
        exons.reverse()
    pieces = []
    for (start, stop) in exons:
        try:
            pieces.append(getSequence(hg, record['chr'], start, stop, strand))
        except Exception:
            # Any lookup failure marks the whole transcript as unusable.
            return None
    return ''.join(pieces)


def _longest_orf(sequence):
    """Return (start, stop) of the longest AUG..stop ORF in `sequence`, or None.

    Each AUG is paired with the first stop codon after it in the same
    frame; `stop` is the stop codon's offset, so the ORF spans
    sequence[start:stop]. Among equally long ORFs the one with the smallest
    start wins.
    """
    starts = [m.start() for m in _START_CODON_RE.finditer(sequence)]
    stopsByFrame = ([], [], [])
    for m in _STOP_CODON_RE.finditer(sequence):
        stopsByFrame[m.start() % 3].append(m.start())
    best = None
    for start in starts:
        frameStops = stopsByFrame[start % 3]
        k = bisect.bisect_right(frameStops, start)
        if k == len(frameStops):
            continue  # no in-frame stop downstream of this AUG
        stop = frameStops[k]
        if best is None or stop - start > best[1] - best[0]:
            best = (start, stop)
    return best


def _translate(sequence, start, stop):
    """Translate sequence[start:stop] codon by codon; unknown codons become '*'."""
    return ''.join([_CODON_TABLE.get(sequence[i:i + 3], '*')
                    for i in range(start, stop, 3)])


def run():
    """Command-line entry point.

    Usage: python <script> genome gtf outfilename [-geneName] [-oId]

    For every transcript in the GTF, it builds the spliced sequence and
    finds the longest ORF. It writes one tab-separated row per transcript
    with these columns: gene/transcript IDs and names, type, chr, span,
    orientation, ORF start/stop-codon offsets ('0,0' when no ORF), RNA
    length, protein length and protein. Transcripts whose sequence cannot
    be retrieved are counted and skipped.
    """
    # Three positional arguments are required (argv[1..3]).
    if len(sys.argv) < 4:
        print('usage: python %s genome gtf outfilename [-geneName] [-oId]' % sys.argv[0])
        print('the -geneName option will create transcript names made of the gene_name attribute and the transcript_id attritbute')
        print('the -oId will assign the oId attritubte to transcript names')
        print('if both option are used, the gene_name and the oId will be merged')
        sys.exit(1)

    genome, inputfilename, outputfilename = sys.argv[1:4]
    doGeneName = '-geneName' in sys.argv
    dooId = '-oId' in sys.argv

    hg = Genome(genome)
    transcripts = _read_annotation(inputfilename, doGeneName, dooId)
    keys = sorted(transcripts)
    print('%d transcripts found in annotation' % len(keys))

    header = '\t'.join(['#GeneID', 'GeneName', 'TranscriptID', 'TranscriptName',
                        'transcriptType', 'chr', 'LeftPos,RightPos',
                        'orientation', 'Start_codon_pos,Stop_codon_pos',
                        'RNA_length', 'protein_length', 'protein'])
    failed = 0
    with open(outputfilename, 'w') as outfile:
        outfile.write(header + '\n')
        for done, key in enumerate(keys, 1):
            (transcriptName, transcriptID, geneName, geneID) = key
            if done % 1000 == 0:
                print('%d transcripts remaining %s' % (len(keys) - done, transcriptName))
            record = transcripts[key]
            sequence = _transcript_sequence(hg, record)
            if sequence is None:
                failed += 1
                continue
            sequence = sequence.upper().replace('T', 'U')

            # None (not (0, 0)) signals "no ORF", so an ORF starting at
            # offset 0 is still translated.
            orf = _longest_orf(sequence)
            if orf is None:
                orfStart, orfStop, protein = 0, 0, ''
            else:
                orfStart, orfStop = orf
                protein = _translate(sequence, orfStart, orfStop)

            exons = sorted(record['exons'])
            left, right = exons[0][0], exons[-1][1]
            row = [geneID, geneName, transcriptID, transcriptName,
                   record['type'], record['chr'], '%d,%d' % (left, right),
                   record['orientation'], '%d,%d' % (orfStart, orfStop),
                   str(len(sequence)), str(len(protein)), protein]
            outfile.write('\t'.join(row) + '\n')

    print('could not retrieve sequence for %d transcripts' % failed)

# Only run when executed as a script, not when imported.
if __name__ == '__main__':
    run()

