##################################
#                                #
# Last modified 08/01/2013       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import math
import random
import string

def getReverseComplement(preliminarysequence):
    """Return the reverse complement of a DNA sequence.

    The input is read from its last character to its first and every
    character is replaced by its complement.  Case is preserved.  Besides
    A/C/G/T/N the IUPAC ambiguity codes R, Y, K, M, S, W, B, D, H and V are
    supported.

    Args:
        preliminarysequence: string of nucleotide codes (may be empty).

    Returns:
        The reverse-complemented string, same length as the input.

    Raises:
        KeyError: if the sequence contains a character that has no entry
            in the complement table.
    """

    upper = {'A':'T','T':'A','G':'C','C':'G','N':'N',
             'R':'Y','Y':'R','K':'M','M':'K','S':'S','W':'W',
             'B':'V','V':'B','D':'H','H':'D'}
    DNA = dict(upper)
    # Lower-case (soft-masked) bases complement to lower-case bases.
    for base in upper:
        DNA[base.lower()] = upper[base].lower()
    # A single join avoids rebuilding the growing string once per base.
    return ''.join([DNA[base] for base in reversed(preliminarysequence)])

def _loadGenome(fasta):
    """Read a FASTA file into a dict mapping each header to its sequence.

    The key is the text between the leading '>' and the next '>' (or end of
    line) of the header line.  Lines that appear before the first header are
    ignored.  Each header is printed as it is encountered.
    """
    GenomeDict = {}
    name = None
    pieces = []
    with open(fasta) as inputdatafile:
        for line in inputdatafile:
            if line.startswith('>'):
                if name is not None:
                    GenomeDict[name] = ''.join(pieces)
                name = line.strip().split('>')[1]
                print(name)
                pieces = []
            elif name is not None:
                pieces.append(line.strip())
    if name is not None:
        GenomeDict[name] = ''.join(pieces)
    return GenomeDict


def _gtfValue(attributes, key):
    """Return the quoted value of `key` in a GTF attribute column.

    Raises IndexError when the attribute is absent.
    """
    return attributes.split(key + ' "')[1].split('";')[0]


def _collectExons(gtf, GenomeDict, doAddChr):
    """Group the exon records of a GTF file by transcript.

    Returns a dict keyed by (geneID, geneName, transcriptName, transcriptID)
    whose values are lists of (chromosome, left, right, orientation) tuples.
    Exons on chromosomes missing from GenomeDict are dropped, as are exons
    whose chromosome differs from the first exon stored for the transcript.
    """
    TranscriptDict = {}
    with open(gtf) as lineslist:
        for j, line in enumerate(lineslist, 1):
            if j % 100000 == 0:
                print('%d lines processed' % j)
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            # Blank or truncated lines have no feature/attribute columns.
            if len(fields) < 9 or fields[2] != 'exon':
                continue
            chrom = fields[0]
            # The prefix has to be applied before any lookup or comparison,
            # otherwise the GTF name never matches the FASTA name.
            if doAddChr:
                if chrom == 'MtDNA':
                    chrom = 'chrM'
                else:
                    chrom = 'chr' + chrom
            if chrom not in GenomeDict:
                continue
            attributes = fields[8]
            geneID = _gtfValue(attributes, 'gene_id')
            if 'gene_name "' in attributes:
                geneName = _gtfValue(attributes, 'gene_name')
            else:
                geneName = geneID
            transcriptID = _gtfValue(attributes, 'transcript_id')
            if 'transcript_name "' in attributes:
                transcriptName = _gtfValue(attributes, 'transcript_name')
            else:
                transcriptName = transcriptID
            transcript = (geneID, geneName, transcriptName, transcriptID)
            exons = TranscriptDict.setdefault(transcript, [])
            if exons and exons[0][0] != chrom:
                continue
            exons.append((chrom, int(fields[3]), int(fields[4]), fields[6]))
    return TranscriptDict


def _writePreMRNAs(outfile, TranscriptDict, GenomeDict, tail, doNoPosInfo):
    """Write one FASTA record per transcript, spanning first to last exon.

    `tail` is appended to the strand-corrected sequence; sequence lines are
    wrapped at 50 characters.
    """
    for g, transcript in enumerate(TranscriptDict, 1):
        if g % 1000 == 0:
            print('%d transcripts sequences processed' % g)
        exons = sorted(TranscriptDict[transcript])
        (geneID, geneName, transcriptName, transcriptID) = transcript
        (chrom, left, _, orientation) = exons[0]
        # The last exon in sort order need not end furthest to the right
        # when exons overlap, so take the maximum end explicitly.
        right = max(exon[2] for exon in exons)
        # GTF coordinates are used as 1-based inclusive positions.
        sequence = GenomeDict[chrom][left-1:right]
        if orientation == '+':
            sense = 'plus_strand'
        elif orientation == '-':
            sense = 'minus_strand'
            sequence = getReverseComplement(sequence)
        else:
            # Without a strand neither the sequence nor the label is defined.
            print('skipping %s: unsupported strand %r' % (transcriptID, orientation))
            continue
        sequence = sequence + tail
        outline = '>' + geneName + ':' + geneID + ':' + transcriptName + ':' + transcriptID
        if not doNoPosInfo:
            outline = outline + ':' + chrom + ':' + str(left) + '-' + str(right) + '-' + sense
        outline = outline + ':pre_mRNA'
        outfile.write(outline + '\n')
        for b in range(0, len(sequence), 50):
            outfile.write(sequence[b:b+50] + '\n')


def run():
    """Extract pre-mRNA sequences for every transcript in a GTF file.

    Command line (read from sys.argv):
        fasta gtf outfilename [-addChrToGTFchrID] [-noPositionalInformation]
        [-polyA tailsize]

    For each transcript the genomic interval from the start of its first
    exon to the end of its last exon is written to `outfilename` in FASTA
    format; minus-strand transcripts are reverse complemented.

    Options:
        -addChrToGTFchrID: prefix GTF chromosome names with 'chr'
            ('MtDNA' becomes 'chrM') before matching them to the FASTA.
        -noPositionalInformation: leave chromosome, coordinates and strand
            out of the FASTA headers.
        -polyA tailsize: append `tailsize` A's to every sequence.

    Exits with status 1 after printing the usage line when the positional
    arguments are missing or the -polyA length is missing or not an integer.
    """

    usage = 'usage: python %s fasta gtf outfilename [-addChrToGTFchrID] [-noPositionalInformation] [-polyA tailsize]' % sys.argv[0]

    # Three positional arguments are required in addition to the script name.
    if len(sys.argv) < 4:
        print(usage)
        sys.exit(1)

    fasta = sys.argv[1]
    gtf = sys.argv[2]
    outputfilename = sys.argv[3]

    tail = ''
    if '-polyA' in sys.argv:
        try:
            tailsize = int(sys.argv[sys.argv.index('-polyA') + 1])
        except (IndexError, ValueError):
            print(usage)
            sys.exit(1)
        tail = 'A' * tailsize
        print('will add a polyA tail of  %d nt' % tailsize)

    doAddChr = '-addChrToGTFchrID' in sys.argv
    doNoPosInfo = '-noPositionalInformation' in sys.argv

    # The output is opened first so that an unwritable path fails before
    # the genome is loaded; the with-block closes it on any exit path.
    with open(outputfilename, 'w') as outfile:
        GenomeDict = _loadGenome(fasta)
        TranscriptDict = _collectExons(gtf, GenomeDict, doAddChr)
        print('Found %d transcripts' % len(TranscriptDict))
        _writePreMRNAs(outfile, TranscriptDict, GenomeDict, tail, doNoPosInfo)

# Only start the extraction when the file is executed as a script, so that
# importing the module does not parse sys.argv or exit the interpreter.
if __name__ == '__main__':
    run()

