##################################
#                                #
# Last modified 2019/04/03       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import re
import sys
import string

def getReverseComplement(preliminarysequence):
    """Return the reverse complement of a DNA sequence.

    The input is read from its last character to its first and every
    character is replaced by its complement. Case is preserved
    (upper-case input gives upper-case output, lower-case gives
    lower-case).

    Besides A, C, G, T and N, the IUPAC ambiguity codes (R, Y, S, W, K,
    M, B, D, H, V) are complemented as well, so sequences containing
    them no longer fail.

    Parameters:
        preliminarysequence: string of nucleotide codes.

    Returns:
        The reverse-complemented string; '' for an empty input.

    Raises:
        KeyError: if the sequence contains a character that is not a
            recognised nucleotide code (e.g. 'X', '-', digits).
    """

    DNA = {
        'A': 'T', 'T': 'A', 'G': 'C', 'C': 'G', 'N': 'N',
        'R': 'Y', 'Y': 'R', 'S': 'S', 'W': 'W', 'K': 'M', 'M': 'K',
        'B': 'V', 'V': 'B', 'D': 'H', 'H': 'D',
        'a': 't', 't': 'a', 'g': 'c', 'c': 'g', 'n': 'n',
        'r': 'y', 'y': 'r', 's': 's', 'w': 'w', 'k': 'm', 'm': 'k',
        'b': 'v', 'v': 'b', 'd': 'h', 'h': 'd',
    }
    # A single join avoids the quadratic cost of growing a string one
    # character at a time, which matters for long transcripts.
    return ''.join([DNA[base] for base in reversed(preliminarysequence)])

def _readFasta(path):
    """Read a FASTA file into a dict mapping header -> upper-cased sequence.

    The header key is the text between the leading '>' and the next '>'
    (if any) on the stripped header line, i.e. normally the whole header
    including any description. Multi-line sequences are concatenated.
    Lines appearing before the first header are ignored, and an empty
    file yields an empty dict.
    """
    records = {}
    ID = None
    chunks = []
    with open(path) as handle:
        for line in handle:
            if line.startswith('>'):
                if ID is not None:
                    records[ID] = ''.join(chunks).upper()
                ID = line.strip().split('>')[1]
                chunks = []
            elif ID is not None:
                chunks.append(line.strip())
    if ID is not None:
        records[ID] = ''.join(chunks).upper()
    return records


def _parseGTFLine(line):
    """Split one GTF line into (fields, ORFID); None for blank/comment lines.

    ORFID is the value of the gene_id attribute in the ninth column.
    """
    stripped = line.strip()
    if not stripped or stripped.startswith('#'):
        return None
    fields = stripped.split('\t')
    ORFID = fields[8].split('gene_id "')[1].split('";')[0]
    return fields, ORFID


def run():
    """Keep the longest ORF per transcript and strand-correct the transcripts.

    Command-line arguments (read from sys.argv):
        sequence.fa    transcript sequences (FASTA)
        ORFfinder.gtf  ORF coordinates; column 1 is the transcript ID and
                       column 9 carries a gene_id "<ORF ID>"; attribute
        ORFfinder.faa  translated ORFs (FASTA) whose headers are ORF IDs
        outprefix      prefix for the three output files

    Output files:
        outprefix.faa  protein sequence of the selected ORF of each transcript
        outprefix.gtf  GTF lines of the selected ORFs; ORFs on the '-' strand
                       are rewritten onto the '+' strand of the flipped
                       transcript
        outprefix.fna  all transcripts, reverse complemented when their
                       selected ORF is on the '-' strand

    Exits with status 1 after printing the usage text when fewer than four
    arguments are supplied.
    """

    # Four positional arguments are required, so argv needs five entries.
    if len(sys.argv) < 5:
        # print(...) with a single argument behaves identically on Python 2
        # and Python 3.
        print('usage: python %s sequence.fa ORFfinder.gtf ORFfinder.faa outprefix' % sys.argv[0])
        print('\tThe script will output the longest ORF for each transcript and will reverse complement the transcript sequence if the ORF is on the minus strand')
        sys.exit(1)

    fasta, gtf, faa, outprefix = sys.argv[1:5]

    SequenceDict = _readFasta(fasta)

    # Transcript ID -> list of (length, left, right, strand, ORF ID).
    ORFDict = {}
    with open(gtf) as linelist:
        for line in linelist:
            parsed = _parseGTFLine(line)
            if parsed is None:
                continue
            fields, ORFID = parsed
            ID = fields[0]
            left = int(fields[3])
            right = int(fields[4])
            strand = fields[6]
            # Create the list only the first time a transcript is seen, so
            # that all of its ORFs are collected.
            if ID not in ORFDict:
                ORFDict[ID] = []
            ORFDict[ID].append((right - left, left, right, strand, ORFID))

    print('finished inputting ORFs')

    NewORFDict = {}
    FAADict = {}

    for ID in ORFDict:
        # Tuples compare by length first; ties are broken by the remaining
        # fields, exactly as a descending sort would.
        longest = max(ORFDict[ID])
        NewORFDict[ID] = longest
        FAADict[longest[4]] = 1

    print('finished filtering ORFs')

    FAASeqDict = _readFasta(faa)

    print('finished inputting translated sequences')

    with open(outprefix + '.faa', 'w') as outfileORF:
        for ID in FAASeqDict:
            if ID in FAADict:
                outfileORF.write('>' + ID + '\n')
                outfileORF.write(FAASeqDict[ID] + '\n')

    print('finished outputting filtered ORFs sequences')

    # Transcripts whose selected ORF lies on the minus strand.
    DoReverse = set()

    with open(outprefix + '.gtf', 'w') as outfileGTF:
        with open(gtf) as linelist:
            for line in linelist:
                parsed = _parseGTFLine(line)
                if parsed is None:
                    continue
                fields, ORFID = parsed
                ID = fields[0]
                if ORFID != NewORFDict[ID][4]:
                    continue
                if NewORFDict[ID][3] != '-':
                    outfileGTF.write(line)
                    continue
                DoReverse.add(ID)
                left = int(fields[3])
                right = int(fields[4])
                ContigLen = len(SequenceDict[ID])
                # NOTE(review): this mirrors the interval as if coordinates
                # were 0-based, half-open. If the input GTF is 1-based
                # inclusive (the usual GTF convention), both values come out
                # one too small -- confirm against the ORFfinder output.
                newleft = ContigLen - right
                newright = ContigLen - left
                outfields = [fields[0], fields[1], fields[2], str(newleft),
                             str(newright), '.', '+', '.', fields[8]]
                outfileGTF.write('\t'.join(outfields) + '\n')

    print('finished outputting filtered ORFs')

    with open(outprefix + '.fna', 'w') as outfileSeq:
        for ID in SequenceDict:
            outfileSeq.write('>' + ID + '\n')
            if ID in DoReverse:
                outfileSeq.write(getReverseComplement(SequenceDict[ID]) + '\n')
            else:
                outfileSeq.write(SequenceDict[ID] + '\n')

    print('finished outputting strand-corrected fasta')

# Run the pipeline only when the file is executed as a script, so that
# importing it (e.g. to reuse getReverseComplement) has no side effects.
if __name__ == '__main__':
    run()

