##################################
#                                #
# Last modified 06/19/2015       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import re
import sys
import string

def getReverseComplement(preliminarysequence):
    """Return the reverse complement of a DNA sequence.

    Upper- and lower-case A, C, G, T and N are supported and case is
    preserved.  Any other character raises KeyError, which callers can use
    to detect sequences that cannot be reverse-complemented.
    """
    DNA = {'A':'T','T':'A','G':'C','C':'G','N':'N','a':'t','t':'a','g':'c','c':'g','n':'n'}
    # Build the result with a single join instead of repeated string
    # concatenation, which is quadratic in the sequence length.
    return ''.join([DNA[base] for base in reversed(preliminarysequence)])

# Differences between each supported NCBI translation table and the standard
# code, stored as (codon reassignments, stop-codon regular expression).
# Codons are spelled in the RNA alphabet because run() converts T to U
# before translating.
_NONSTANDARD_GENETIC_CODES = {
    '2': ({'AGA': 'STOP', 'AGG': 'STOP', 'AUA': 'M', 'UGA': 'W'}, '(AGA|AGG|UAA|UAG)'),
    '3': ({'AUA': 'M', 'CUU': 'T', 'CUC': 'T', 'CUA': 'T', 'CUG': 'T', 'UGA': 'W'}, '(UAA|UAG)'),
    '4': ({'UGA': 'W'}, '(UAA|UAG)'),
    '5': ({'AGA': 'S', 'AGG': 'S', 'AUA': 'M', 'UGA': 'W'}, '(UAA|UAG)'),
    '6': ({'UAA': 'Q', 'UAG': 'Q'}, '(UGA)'),
    '9': ({'AAA': 'N', 'AGA': 'S', 'AGG': 'S', 'UGA': 'W'}, '(UAA|UAG)'),
    '10': ({'UGA': 'C'}, '(UAA|UAG)'),
    '12': ({'CUG': 'S'}, '(UGA|UAA|UAG)'),
    '13': ({'AGA': 'G', 'AGG': 'G', 'AUA': 'M', 'UGA': 'W'}, '(UAA|UAG)'),
    '14': ({'AAA': 'N', 'AGA': 'S', 'AGG': 'S', 'UAA': 'Y', 'UGA': 'W'}, '(UAG)'),
    '16': ({'UAG': 'L'}, '(UGA|UAA)'),
    '21': ({'UGA': 'W', 'AUA': 'M', 'AGA': 'S', 'AGG': 'S', 'AAA': 'N'}, '(UAA|UAG)'),
    '22': ({'UCA': 'STOP', 'UAG': 'L'}, '(UCA|UGA|UAA)'),
    '23': ({'UUA': 'STOP'}, '(UUA|UGA|UAA|UAG)'),
    '24': ({'AGA': 'S', 'AGG': 'K', 'UGA': 'W'}, '(UAA|UAG)'),
    '25': ({'UGA': 'G'}, '(UAA|UAG)'),
}


def _printUsage(progName):
    """Print the command-line synopsis and the option descriptions."""
    print('usage: python %s genome.fa gtf outfilename [-NCBIGFFFNAProk] [-geneName] [-oId] [-CDSinGTF] [-CDSsuppressSTOP] [-0-based] [-splitChrNamesBySpace] [-NonStandardGeneticCode 2|3|4|5|6|9|10|12|13|14|16|21|22|23|24|25] [-skipChr chrM,[etc.]] [-CUB outfilename]' % progName)
    helpLines = (
        'the -geneName option will create transcript names made of the gene_name attribute and the transcript_id attritbute',
        'the -oId will assign the oId attritubte to transcript names',
        'if both option are used, the gene_name and the oId will be merged',
        'if the -CDSinGTF option is used, the script will not look for ORF but usse the CDS annotation provided in the GTF file',
        'the -CDSsuppressSTOP option can only be used together with the -CDS option; it will not trigger an error when stop codons are encountered and will instead encode them as a * amino acid',
        'the -0-based option will shift all coordinate by 1 to the left compared to the normal treatment',
        'Nonstandard genetic codes:',
        '\t2. The Vertebrate Mitochondrial Code',
        '\t3. The Yeast Mitochondrial Code',
        '\t4. The Mold, Protozoan, and Coelenterate Mitochondrial Code and the Mycoplasma/Spiroplasma Code',
        '\t5. The Invertebrate Mitochondrial Code',
        '\t6. The Ciliate, Dasycladacean and Hexamita Nuclear Code',
        '\t9. The Echinoderm and Flatworm Mitochondrial Code',
        '\t10. The Euplotid Nuclear Code',
        '\t12. The Alternative Yeast Nuclear Code',
        '\t13. The Ascidian Mitochondrial Code',
        '\t14. The Alternative Flatworm Mitochondrial Code',
        '\t16. Chlorophycean Mitochondrial Code',
        '\t21. Trematode Mitochondrial Code',
        '\t22. Scenedesmus obliquus Mitochondrial Code',
        '\t23. Thraustochytrium Mitochondrial Code',
        '\t24. Pterobranchia Mitochondrial Code',
        '\t25. Candidate Division SR1 and Gracilibacteria Code',
        '\tAlternative initiation codons are not considered',
    )
    for helpLine in helpLines:
        print(helpLine)


def _optionValue(flag):
    """Return the command-line argument following flag; exit if there is none."""
    position = sys.argv.index(flag) + 1
    if position >= len(sys.argv):
        print('option %s requires a value' % flag)
        sys.exit(1)
    return sys.argv[position]


def _buildGeneticCode(NSGC):
    """Return (CodonDict, STOPCODONREGEXP) for the requested genetic code.

    NSGC is the NCBI translation table number as a string, or None for the
    standard code.  A number without an entry in _NONSTANDARD_GENETIC_CODES
    falls back to the standard code.
    """
    CodonDict={'GCU':'A', 'GCC':'A', 'GCA':'A', 'GCG':'A',
               'UUA':'L', 'UUG':'L', 'CUU':'L', 'CUC':'L', 'CUA':'L', 'CUG':'L',
               'CGU':'R', 'CGC':'R', 'CGA':'R', 'CGG':'R', 'AGA':'R', 'AGG':'R',
               'AAA':'K', 'AAG':'K',
               'AAU':'N', 'AAC':'N',
               'AUG':'M',
               'GAU':'D', 'GAC':'D',
               'UUU':'F', 'UUC':'F',
               'UGU':'C', 'UGC':'C',
               'CCU':'P', 'CCC':'P', 'CCA':'P', 'CCG':'P',
               'CAA':'Q', 'CAG':'Q',
               'UCU':'S', 'UCC':'S', 'UCA':'S', 'UCG':'S', 'AGU':'S', 'AGC':'S',
               'GAA':'E', 'GAG':'E',
               'ACU':'T', 'ACC':'T', 'ACA':'T', 'ACG':'T',
               'GGU':'G', 'GGC':'G', 'GGA':'G', 'GGG':'G',
               'UGG':'W',
               'CAU':'H', 'CAC':'H',
               'UAU':'Y', 'UAC':'Y',
               'AUU':'I', 'AUC':'I', 'AUA':'I',
               'GUU':'V', 'GUC':'V', 'GUA':'V', 'GUG':'V',
               'START':'AUG',
               'UAA':'STOP',
               'UGA':'STOP',
               'UAG':'STOP'}
    STOPCODONREGEXP = '(UGA|UAA|UAG)'
    if NSGC in _NONSTANDARD_GENETIC_CODES:
        (reassignments, STOPCODONREGEXP) = _NONSTANDARD_GENETIC_CODES[NSGC]
        CodonDict.update(reassignments)
    return (CodonDict, STOPCODONREGEXP)


def _readFasta(fasta, doNCBIGFFFNAProk, doSplitChrNamesBySpace):
    """Read a FASTA file into a dict of upper-cased sequences keyed by name.

    With doNCBIGFFFNAProk the name is the text between 'ref|' and the next
    '|' in the header; with doSplitChrNamesBySpace it is the header up to the
    first space; otherwise it is the whole header after '>'.  Lines that
    precede the first header are ignored.  The name of the first record is
    printed, as the script has always done.
    """
    SequenceDict = {}
    ID = None
    pieces = []
    with open(fasta) as inputdatafile:
        for line in inputdatafile:
            if line.startswith('>'):
                if ID is not None:
                    SequenceDict[ID] = ''.join(pieces).upper()
                header = line.strip().split('>')[1]
                if doNCBIGFFFNAProk:
                    newID = header.split('ref|')[1].split('|')[0]
                elif doSplitChrNamesBySpace:
                    newID = header.split(' ')[0]
                else:
                    newID = header
                if ID is None:
                    print(newID)
                ID = newID
                pieces = []
            else:
                pieces.append(line.strip())
    if ID is not None:
        SequenceDict[ID] = ''.join(pieces).upper()
    return SequenceDict


def _transcriptKey(attributes, doNCBIGFFFNAProk, doGeneName, dooId):
    """Return (transcriptName, transcriptID, geneName, geneID) for one record.

    attributes is the ninth column of the annotation line.  A required
    attribute that is missing raises IndexError.
    """
    def quoted(prefix):
        # Value of a GTF-style attribute: the text between prefix and '";'.
        return attributes.split(prefix)[1].split('";')[0]

    if doNCBIGFFFNAProk:
        geneID = attributes.split('ID=')[1].split(';')[0]
        geneName = attributes.split(';product=')[1].split(';')[0]
        return (geneName, geneID, geneName, geneID)

    transcriptID = quoted('transcript_id "')
    geneID = quoted('gene_id "')
    if 'transcript_name' in attributes:
        transcriptName = quoted('transcript_name "')
    else:
        transcriptName = transcriptID
    if 'gene_name' in attributes:
        geneName = quoted('gene_name "')
    else:
        geneName = geneID
    if doGeneName and dooId:
        try:
            transcriptName = quoted(' gene_name "') + '.' + quoted(' oId "')
        except IndexError:
            # No gene_name attribute: fall back to the oId alone.
            transcriptName = quoted(' oId "')
    elif doGeneName:
        transcriptName = quoted(' gene_name "') + '.' + quoted(' transcript_id "')
    elif dooId:
        transcriptName = quoted(' oId "')
    return (transcriptName, transcriptID, geneName, geneID)


def _readAnnotation(inputfilename, SequenceDict, featureType, ChrToSkip, doOBased, doNCBIGFFFNAProk, doGeneName, dooId):
    """Collect the features of type featureType per transcript.

    Returns a dict keyed by the tuple from _transcriptKey; each value holds
    'chr', 'orientation', 'type' and 'exons', a list of (start, stop) pairs
    used directly as Python slice bounds on the chromosome sequence.
    Records on chromosomes in ChrToSkip or absent from SequenceDict are
    dropped, as are lines with fewer than nine tab-separated columns.
    """
    TranscriptDict = {}
    linesSeen = 0
    with open(inputfilename) as listoflines:
        for line in listoflines:
            if line.startswith('#'):
                continue
            linesSeen += 1
            if linesSeen % 100000 == 0:
                print('%d lines processed' % linesSeen)
            fields = line.split('\t')
            if len(fields) < 9:
                # Blank or malformed line: nothing to extract.
                continue
            if fields[2] != featureType:
                continue
            chrom = fields[0]
            if chrom in ChrToSkip or chrom not in SequenceDict:
                continue
            key = _transcriptKey(fields[8], doNCBIGFFFNAProk, doGeneName, dooId)
            # NOTE(review): by default the slice is [start:end+1] with the
            # annotation's own numbers, i.e. one base to the right of the
            # usual 1-based inclusive reading; -0-based shifts it back.
            # Kept as in the original -- confirm against the expected input.
            start = int(fields[3])
            stop = int(fields[4]) + 1
            if doOBased:
                start = start - 1
                stop = stop - 1
            if key not in TranscriptDict:
                TranscriptDict[key] = {'exons': []}
            entry = TranscriptDict[key]
            entry['chr'] = chrom
            entry['orientation'] = fields[6]
            entry['exons'].append((start, stop))
            if 'transcript_type' in line:
                entry['type'] = fields[8].split(' transcript_type "')[1].split('";')[0]
            else:
                entry['type'] = 'unknown'
    return TranscriptDict


def _transcriptRNA(entry, chromosomeSequence):
    """Return the spliced RNA sequence of a transcript, or None on failure.

    entry['exons'] must already be sorted.  Minus-strand transcripts are
    reverse-complemented exon by exon; None is returned when an exon holds a
    base getReverseComplement cannot handle.  Any orientation other than '+'
    or '-' yields an empty sequence.
    """
    exons = entry['exons']
    if entry['orientation'] == '+':
        pieces = [chromosomeSequence[start:stop] for (start, stop) in exons]
    elif entry['orientation'] == '-':
        try:
            pieces = [getReverseComplement(chromosomeSequence[start:stop]) for (start, stop) in reversed(exons)]
        except KeyError:
            return None
    else:
        pieces = []
    return ''.join(pieces).upper().replace('T', 'U')


def _countCodon(CodonUsageDict, aminoAcid, codon):
    """Increment the usage counter of codon when usage is being collected."""
    if CodonUsageDict is not None:
        CodonUsageDict[aminoAcid][codon] += 1


def _translateCDS(sequence, CodonDict, CodonUsageDict, doSupressSTOP, geneID, transcriptID, chrom):
    """Translate an annotated CDS from its first base and return the protein.

    Codons missing from CodonDict (ambiguity codes and the like) become 'X'.
    An in-frame stop codon becomes 'X' when doSupressSTOP is set and
    terminates the program otherwise.
    """
    residues = []
    # NOTE(review): the len(sequence)-6 bound leaves the last two codons
    # untranslated; kept as in the original -- confirm this is intended.
    for i in range(0, len(sequence) - 6, 3):
        codon = sequence[i:i+3]
        aminoAcid = CodonDict.get(codon)
        if aminoAcid is None:
            residues.append('X')
        elif aminoAcid == 'STOP':
            details = '%s %s %s %s %s %s' % (geneID, transcriptID, codon, i, len(sequence), chrom)
            if doSupressSTOP:
                print('problem with stop codon assignment, substituting with X')
                print(details)
                residues.append('X')
            else:
                print('problem with stop codon assignment, exiting')
                print(details)
                sys.exit(1)
        else:
            residues.append(aminoAcid)
            _countCodon(CodonUsageDict, aminoAcid, codon)
    return ''.join(residues)


def _findLongestORF(sequence, stopFinder):
    """Return (start, stop) of the longest AUG..stop ORF, or (0, 0) if none.

    stopFinder is a compiled zero-width look-ahead pattern, so stop codons
    that overlap each other (e.g. UAG and AGA in 'UAGA') are all reported.
    For every AUG only the first in-frame stop downstream is considered;
    ties are resolved in favour of the most upstream AUG.
    """
    startPositions = [mo.start() for mo in re.finditer('AUG', sequence)]
    stopPositions = [mo.start() for mo in stopFinder.finditer(sequence)]
    longestORF = (0, 0)
    for StartPos in startPositions:
        for StopPos in stopPositions:
            if StopPos > StartPos and ((StopPos - StartPos) % 3) == 0:
                if StopPos - StartPos > longestORF[1] - longestORF[0]:
                    longestORF = (StartPos, StopPos)
                break
    return longestORF


def _translateORF(sequence, STARTPOS, STOPCODONPOS, CodonDict, CodonUsageDict):
    """Translate sequence[STARTPOS:STOPCODONPOS] codon by codon.

    Codons missing from CodonDict become 'X'.  A stop codon inside the ORF
    means the ORF search and the codon table disagree, so the program exits.
    """
    residues = []
    for i in range(STARTPOS, STOPCODONPOS, 3):
        codon = sequence[i:i+3]
        aminoAcid = CodonDict.get(codon)
        if aminoAcid is None:
            residues.append('X')
        elif aminoAcid == 'STOP':
            print('problem with stop codon assignment, exiting')
            sys.exit(1)
        else:
            residues.append(aminoAcid)
            _countCodon(CodonUsageDict, aminoAcid, codon)
    return ''.join(residues)


def _writeCodonUsage(CUBfile, CodonUsageDict):
    """Write per-amino-acid codon counts and relative frequencies to CUBfile."""
    with open(CUBfile, 'w') as outfile:
        outfile.write('#AA\tcodon\tFrequency\tCounts' + '\n')
        for AA in sorted(CodonUsageDict):
            counts = CodonUsageDict[AA]
            total = float(sum(counts.values()))
            for COD in sorted(counts):
                if total == 0:
                    frequency = 'nan'
                else:
                    frequency = str(counts[COD] / total)
                outfile.write(AA + '\t' + COD + '\t' + frequency + '\t' + str(counts[COD]) + '\n')


def run():
    """Translate annotated transcripts and write one table row per transcript.

    Reads sys.argv: genome FASTA, GTF/GFF annotation and output file name,
    followed by the optional flags listed by _printUsage.  Without
    -CDSinGTF / -NCBIGFFFNAProk the exons of each transcript are spliced and
    the longest AUG-to-stop ORF is translated; with either flag the CDS
    records are spliced and translated from their first base.  With -CUB a
    codon usage table is written as well.  Exits with status 1 on a usage
    error or an unexpected in-frame stop codon.
    """
    # Three positional arguments are required, so argv needs four entries.
    if len(sys.argv) < 4:
        _printUsage(sys.argv[0])
        sys.exit(1)

    fasta = sys.argv[1]
    inputfilename = sys.argv[2]
    outputfilename = sys.argv[3]

    doGeneName = '-geneName' in sys.argv
    dooId = '-oId' in sys.argv
    doNCBIGFFFNAProk = '-NCBIGFFFNAProk' in sys.argv
    # -NCBIGFFFNAProk implies working from the CDS records.
    doCDS = doNCBIGFFFNAProk or '-CDSinGTF' in sys.argv
    # Always defined, so CDS mode entered through -NCBIGFFFNAProk alone does
    # not hit an undefined name at the first stop codon.
    doSupressSTOP = '-CDSsuppressSTOP' in sys.argv
    doSplitChrNamesBySpace = '-splitChrNamesBySpace' in sys.argv
    doOBased = '-0-based' in sys.argv

    ChrToSkip = set()
    if '-skipChr' in sys.argv:
        ChrToSkip = set(_optionValue('-skipChr').split(','))

    doCUB = '-CUB' in sys.argv
    CUBfile = None
    if doCUB:
        CUBfile = _optionValue('-CUB')

###### http://www.ncbi.nlm.nih.gov/Taxonomy/Utils/wprintgc.cgi#SG6

    NSGC = None
    if '-NonStandardGeneticCode' in sys.argv:
        NSGC = _optionValue('-NonStandardGeneticCode')
        print('Will use nonstandard genetic code %s' % NSGC)
    (CodonDict, STOPCODONREGEXP) = _buildGeneticCode(NSGC)
    # Look-ahead makes the search zero-width so overlapping stops are found.
    stopFinder = re.compile('(?=' + STOPCODONREGEXP + ')')

    CodonUsageDict = None
    if doCUB:
        CodonUsageDict = {}
        for COD in CodonDict:
            if COD == 'START':
                continue
            CodonUsageDict.setdefault(CodonDict[COD], {})[COD] = 0

    SequenceDict = _readFasta(fasta, doNCBIGFFFNAProk, doSplitChrNamesBySpace)
    print('finished inputting sequence')

    print('Inputting annotation')
    if doCDS:
        featureType = 'CDS'
    else:
        featureType = 'exon'
    TranscriptDict = _readAnnotation(inputfilename, SequenceDict, featureType, ChrToSkip,
                                     doOBased, doNCBIGFFFNAProk, doGeneName, dooId)

    keys = sorted(TranscriptDict)
    header = '\t'.join(['#GeneID', 'GeneName', 'TranscriptID', 'TranscriptName',
                        'transcriptType', 'chr', 'LeftPos,RightPos', 'orientation',
                        'Start_codon_pos,Stop_codon_pos', 'RNA_length',
                        'protein_length', 'protein'])
    P = 0
    j = 0
    print('%d transcripts found in annotation' % len(keys))
    with open(outputfilename, 'w') as outfile:
        outfile.write(header + '\n')
        for IDs in keys:
            (transcriptName, transcriptID, geneName, geneID) = IDs
            j += 1
            if j % 1000 == 0:
                print('%d transcripts remaining %s' % (len(keys) - j, transcriptName))
            entry = TranscriptDict[IDs]
            chrom = entry['chr']
            entry['exons'].sort()
            left = entry['exons'][0][0]
            right = entry['exons'][-1][1]
            sequence = _transcriptRNA(entry, SequenceDict[chrom])
            if sequence is None:
                P += 1
                continue
            if doCDS:
                ORF = _translateCDS(sequence, CodonDict, CodonUsageDict, doSupressSTOP,
                                    geneID, transcriptID, chrom)
                (STARTPOS, STOPCODONPOS) = (0, len(sequence) - 1)
            else:
                (STARTPOS, STOPCODONPOS) = _findLongestORF(sequence, stopFinder)
                ORF = ''
                # (0, 0) means no ORF was found.  Testing stop > start rather
                # than start == 0 keeps ORFs that begin at the first base.
                if STOPCODONPOS > STARTPOS:
                    ORF = _translateORF(sequence, STARTPOS, STOPCODONPOS, CodonDict, CodonUsageDict)
            outline = '\t'.join([geneID, geneName, transcriptID, transcriptName,
                                 entry['type'], chrom,
                                 str(left) + ',' + str(right),
                                 entry['orientation'],
                                 str(STARTPOS) + ',' + str(STOPCODONPOS),
                                 str(len(sequence)), str(len(ORF)), ORF])
            outfile.write(outline + '\n')

    print('could not retrieve sequence for %d transcripts' % P)

    if doCUB:
        _writeCodonUsage(CUBfile, CodonUsageDict)

# Run the pipeline only when executed as a script, so that importing this
# module (e.g. to reuse getReverseComplement) has no side effects.
if __name__ == '__main__':
    run()

