##################################
#                                #
# Last modified 2021/03/31       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import math
import random
import string

def getReverseComplement(preliminarysequence):
    
    DNA = {'A':'T',
           'T':'A',
           'G':'C',
           'C':'G',
           'N':'N',
           'X':'N',
           'a':'t',
           't':'a',
           'g':'c',
           'c':'g',
           'n':'N',
           'x':'N',
           'R':'N',
           'r':'N',
           'M':'N',
           'm':'N',
           'Y':'N',
           'y':'N',
           'S':'N',
           's':'N',
           'K':'N',
           'k':'N',
           'W':'N',
           'w':'N',
           'H':'N',
           'h':'N',
           'V':'N',
           'v':'N',
           'B':'N',
           'b':'N',
           'D':'N',
           'd':'N'
            }
    sequence=''
    for j in range(len(preliminarysequence)):
        sequence=sequence+DNA[preliminarysequence[len(preliminarysequence)-j-1]]
    return sequence

def _read_fasta(path, header_field=None):
    """Load a FASTA file into a dict mapping record name to its sequence.

    The record name is the text after the leading '>' of the header line;
    when header_field is given, the name is further split on tabs and the
    field with that index is used.  Each record name is printed as it is
    read.  Lines that appear before the first header are ignored.
    """
    genome = {}
    name = None
    chunks = []
    with open(path) as handle:
        for line in handle:
            if line.startswith('>'):
                if name is not None:
                    genome[name] = ''.join(chunks)
                name = line.strip().split('>')[1]
                if header_field is not None:
                    name = name.split('\t')[header_field]
                print(name)
                chunks = []
            elif name is not None:
                chunks.append(line.strip())
    if name is not None:
        genome[name] = ''.join(chunks)
    return genome


def _gtf_attribute(attributes, key):
    """Return the quoted value of `key` in a GTF attribute column.

    Raises IndexError when the attribute is absent.
    """
    return attributes.split(key + ' "')[1].split('";')[0]


def _read_gtf_exons(path, genome, use_cds_utrs, add_chr):
    """Group the exon (or CDS/UTR) records of a GTF file by transcript.

    Returns a dict keyed by (geneID, geneName, transcriptName, transcriptID)
    whose values are lists of (geneName, chromosome, left, right, strand).
    Records whose chromosome is not a key of `genome` are skipped, as are
    comment lines and lines with fewer than nine tab-separated fields.
    """
    if use_cds_utrs:
        wanted = ('CDS', '5UTR', '3UTR')
    else:
        wanted = ('exon',)
    transcripts = {}
    with open(path) as handle:
        for line_number, line in enumerate(handle, 1):
            if line_number % 100000 == 0:
                print('%d lines processed' % line_number)
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            if len(fields) < 9 or fields[2] not in wanted:
                continue
            chrom = fields[0]
            # Rename BEFORE the membership test: the renamed ID is the one
            # used to index the genome, so it is the one that must exist.
            if add_chr:
                if chrom == 'MtDNA':
                    chrom = 'chrM'
                else:
                    chrom = 'chr' + chrom
            if chrom not in genome:
                continue
            attributes = fields[8]
            gene_id = _gtf_attribute(attributes, 'gene_id')
            if 'gene_name "' in attributes:
                gene_name = _gtf_attribute(attributes, 'gene_name')
            else:
                gene_name = gene_id
            transcript_id = _gtf_attribute(attributes, 'transcript_id')
            if 'transcript_name "' in attributes:
                transcript_name = _gtf_attribute(attributes, 'transcript_name')
            else:
                transcript_name = transcript_id
            key = (gene_id, gene_name, transcript_name, transcript_id)
            record = (gene_name, chrom, int(fields[3]), int(fields[4]), fields[6])
            transcripts.setdefault(key, []).append(record)
    return transcripts


def run():
    """Write the spliced sequence of every annotated transcript as FASTA.

    Reads the command line from sys.argv: a genome FASTA, a GTF file and an
    output file name, followed by optional flags (see the usage message).
    For each transcript the exon (or CDS/UTR) intervals are sorted,
    cut out of the genome using 1-based inclusive coordinates, concatenated,
    and reverse-complemented for minus-strand transcripts.  The strand is
    taken from the first interval after sorting.  Sequences are written 50
    characters per line.

    Exits with status 1 after printing the usage message when fewer than
    three positional arguments are supplied.
    """
    # Three positional arguments are read below, so argv needs >= 4 entries.
    if len(sys.argv) < 4:
        print('usage: python %s fasta gtf outfilename [-CDSUTRs] [-polyA length] [-fastaChrFieldID ID] [-addChrToGTFchrID] [-noPositionalInformation] [-KeepUundeterminedStrand] [-TtoU]' % sys.argv[0])
        print('\t use the [-KeepUundeterminedStrand] option if the file contains strands specified with a dot; those will be considered to be on the plus strand if the option is on, otherwise they will be skipped')
        print('\t use the [-CDSUTRs] option if the file does not contain exon definitions but only CDS and UTR entries')
        sys.exit(1)

    args = sys.argv
    fasta = args[1]
    gtf = args[2]
    outputfilename = args[3]

    tail = ''
    if '-polyA' in args:
        tailsize = int(args[args.index('-polyA') + 1])
        tail = 'A' * tailsize
        print('will add a polyA tail of  %s nt' % tailsize)

    use_cds_utrs = '-CDSUTRs' in args
    if use_cds_utrs:
        print('will use CDS and UTR definitions instead of exons')

    keep_undetermined = '-KeepUundeterminedStrand' in args
    add_chr = '-addChrToGTFchrID' in args

    header_field = None
    if '-fastaChrFieldID' in args:
        header_field = int(args[args.index('-fastaChrFieldID') + 1])

    no_position = '-noPositionalInformation' in args

    to_rna = '-TtoU' in args
    if to_rna:
        print('will output RNA sequence')

    # Opened before the inputs are parsed so that an unwritable output path
    # fails immediately rather than after the genome has been loaded.
    outfile = open(outputfilename, 'w')
    try:
        genome = _read_fasta(fasta, header_field)
        transcripts = _read_gtf_exons(gtf, genome, use_cds_utrs, add_chr)

        print('Found %d transcripts' % len(transcripts))
        for count, transcript in enumerate(transcripts, 1):
            if count % 1000 == 0:
                print('%d transcripts sequences processed' % count)
            (geneID, geneName, transcriptName, transcriptID) = transcript
            exons = sorted(transcripts[transcript])

            strand = exons[0][4]
            if strand == '.':
                if not keep_undetermined:
                    continue
                strand = '+'
            if strand in ('+', 'F'):
                sense = 'plus_strand'
            elif strand in ('-', 'R'):
                sense = 'minus_strand'
            else:
                print('skipping transcript %s: unrecognized strand %r' % (transcriptID, strand))
                continue

            sequence = ''.join([genome[exon[1]][exon[2] - 1:exon[3]] for exon in exons])
            if sense == 'minus_strand':
                # Pieces are joined in ascending genomic order, so reversing
                # the whole string also puts the exons in transcript order.
                sequence = getReverseComplement(sequence)

            outline = '>' + geneName + ':' + geneID + ':' + transcriptName + ':' + transcriptID
            if not no_position:
                left_end = min([exon[2] for exon in exons])
                right_end = max([exon[3] for exon in exons])
                # The chromosome reported is that of the last sorted interval.
                outline = outline + ':' + exons[-1][1] + ':' + str(left_end) + '-' + str(right_end) + '-' + sense
            outfile.write(outline + '\n')

            sequence = sequence + tail
            if to_rna:
                sequence = sequence.replace('T', 'U').replace('t', 'u')
            for start in range(0, len(sequence), 50):
                outfile.write(sequence[start:start + 50] + '\n')
    finally:
        outfile.close()

# Script entry point. The call is unguarded, so it executes (and reads
# sys.argv) whenever this file is run or imported.
run()

