##################################
#                                #
# Last modified 08/01/2013       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import os
import math
from sets import Set

def _gtf_attribute(attributes, key, default=None):
    """Return the quoted value of *key* in a GTF attribute column, or *default* if absent."""
    marker = key + ' "'
    if marker not in attributes:
        return default
    return attributes.split(marker)[1].split('";')[0]


def _load_single_chromosome_transcripts(gtf_path):
    """Collect the exons of every transcript in a GTF file.

    Returns a dict keyed by (geneID, geneName, transcriptID, transcriptName);
    each value is a dict whose 'exons' entry lists (chr, start, stop, strand)
    tuples. Transcripts whose exons lie on more than one chromosome are dropped.
    """
    transcripts = {}
    with open(gtf_path) as gtf:
        for line in gtf:
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            # Blank or truncated lines cannot describe an exon; skip them.
            if len(fields) < 9 or fields[2] != 'exon':
                continue
            attributes = fields[8]
            transcript_id = _gtf_attribute(attributes, 'transcript_id')
            gene_id = _gtf_attribute(attributes, 'gene_id')
            if transcript_id is None or gene_id is None:
                raise ValueError('GTF exon line lacks gene_id or transcript_id: ' + line.strip())
            # The *_name attributes are optional; fall back to the IDs.
            transcript_name = _gtf_attribute(attributes, 'transcript_name', transcript_id)
            gene_name = _gtf_attribute(attributes, 'gene_name', gene_id)
            key = (gene_id, gene_name, transcript_id, transcript_name)
            exon = (fields[0], int(fields[3]), int(fields[4]), fields[6])
            transcripts.setdefault(key, {'exons': []})['exons'].append(exon)

    # Iterate over a copy of the keys because entries are deleted in the loop.
    for key in list(transcripts):
        if len(set(exon[0] for exon in transcripts[key]['exons'])) > 1:
            del transcripts[key]
    return transcripts


def _exonic_length(exons):
    """Return the summed exon length as a float.

    GTF start/stop coordinates are 1-based and inclusive, so an exon covers
    |stop - start| + 1 bases.
    """
    return float(sum(abs(stop - start) + 1 for (_chrom, start, stop, _strand) in exons))


def _simulate_and_append(gem_reads, fasta_path, n_pairs, read_length,
                         fragment_length, frag_length_stdev, outprefix,
                         fastq1, fastq2):
    """Run GemReads on one fasta file and append its paired reads to the open FASTQ outputs.

    GemReads writes to <outprefix>.temp_fir.fastq / <outprefix>.temp_sec.fastq;
    both temporary files are removed afterwards, even on failure.
    Raises RuntimeError if GemReads exits with a non-zero status.
    """
    import subprocess

    temp_prefix = outprefix + '.temp'
    # The error model is looked up relative to the directory holding GemReads-modified.py.
    model = gem_reads.split('GemReads-modified.py')[0] + 'models/ill100v5_p.gzip'
    # An argument list (no shell) keeps names containing shell metacharacters intact.
    cmd = ['python', gem_reads, '-r', fasta_path, '-m', model, '-p', '-q', '33',
           '-n', str(n_pairs), '-l', str(read_length), '-u', str(fragment_length),
           '-s', str(frag_length_stdev), '-o', temp_prefix]
    temp_files = (temp_prefix + '_fir.fastq', temp_prefix + '_sec.fastq')
    try:
        status = subprocess.call(cmd)
        if status != 0:
            raise RuntimeError('GemReads failed (exit status %s) for %s' % (status, fasta_path))
        for temp_path, destination in zip(temp_files, (fastq1, fastq2)):
            with open(temp_path) as handle:
                for line in handle:
                    destination.write(line)
    finally:
        for temp_path in temp_files:
            if os.path.exists(temp_path):
                os.remove(temp_path)


def run():
    """Simulate paired-end reads for a transcriptome with a given FPKM profile.

    Reads its settings from sys.argv (see the usage message): an FPKM table, the
    number of read pairs, directories with per-transcript mRNA and pre-mRNA
    fasta files, the intronic fraction, read/fragment length parameters, the
    path to GemReads-modified.py, a GTF annotation and an output prefix.

    Writes <outprefix>.rescaled_FPKM (per-transcript rescaled values and read
    counts) and <outprefix>.read1.fastq / <outprefix>.read2.fastq (the
    concatenated simulated reads). Exits with status 1 if too few arguments
    are given.
    """
    # Eleven positional arguments are required, i.e. sys.argv[1] .. sys.argv[11].
    if len(sys.argv) < 12:
        print('usage: python %s transcript_FPKM_table NumberReadPairs transcriptome-fasta-directory premRNA-fasta-directory IntronicFraction read_length fragment_length fragment_length_stdev path_to_GemReads-modified GTF outprefix' % sys.argv[0])
        print('\t transcriptome-fasta-directory should contain one separate fasta file for each transcript; those should be labelled as follows: ZZZ3:ENSG00000036549.8:ZZZ3-002:ENST00000370798.1.fa')
        print('\t premRNA-fasta-directory should contain one separate fasta file for each premRNA; those should be labelled as follows: ZZZ3:ENSG00000036549.8:ZZZ3-002:ENST00000370798.1:pre_mRNA.fa')
        print('\t the transcript_FPKM_table files should contain the specified FPKM value (for transcripts only, the pre-mRNA reads will be calculated based on those)')
        print('\t transcript_FPKM_table format: #GeneID\tGeneName\tGeneFPKM\tTranscriptID\tTranscriptName\tTranscriptFPKM\tFMI')
        sys.exit(1)

    fpkm_table = sys.argv[1]
    n_read_pairs = int(sys.argv[2])
    mrna_fasta_dir = sys.argv[3]
    premrna_fasta_dir = sys.argv[4]
    intronic_fraction = float(sys.argv[5])
    read_length = int(sys.argv[6])
    fragment_length = int(sys.argv[7])
    frag_length_stdev = int(sys.argv[8])
    gem_reads = sys.argv[9]
    gtf_path = sys.argv[10]
    outprefix = sys.argv[11]

    transcripts = _load_single_chromosome_transcripts(gtf_path)
    print('finished parsing GTF file')

    total_fpkm = 0
    total_fpm = 0
    with open(fpkm_table) as table:
        for line in table:
            if line.startswith('#') or line.startswith('tracking_id'):
                continue
            if not line.strip():
                continue
            fields = line.strip().split('\t')
            fpkm = float(fields[5])
            total_fpkm += fpkm
            transcript = (fields[0], fields[1], fields[3], fields[4])
            # Report transcripts missing from the GTF (or dropped as
            # multi-chromosome) instead of failing on the lookup.
            if transcript not in transcripts:
                print(transcript)
                continue
            entry = transcripts[transcript]
            entry['FPKM'] = fpkm
            total_fpm += fpkm * (_exonic_length(entry['exons']) / 1000.)

    print('finished parsing simulated FPKMs %s %s' % (total_fpkm, total_fpm))

    # With a zero total there is nothing to rescale and no reads to simulate.
    if total_fpm == 0:
        expressed = []
    else:
        expressed = [(key, entry) for key, entry in transcripts.items() if 'FPKM' in entry]

    with open(outprefix + '.rescaled_FPKM', 'w') as rescaled_out, \
            open(outprefix + '.read1.fastq', 'w') as fastq1, \
            open(outprefix + '.read2.fastq', 'w') as fastq2:
        # Column labels follow the order in which the values are written below.
        rescaled_out.write('#geneName\tgeneID\ttranscriptName\ttranscriptID\tFPKM\tlength\tFPM\tIF_reads\tEF_reads' + '\n')

        for transcript, entry in expressed:
            fpkm = entry['FPKM']
            length = _exonic_length(entry['exons'])
            fpm = fpkm * (length / 1000.)
            rescaled_fpm = fpm / (total_fpm / 1000000.)
            rescaled_fpkm = fpkm / (total_fpm / 1000000.)
            reads = rescaled_fpm * (n_read_pairs / 1000000.)
            if reads == 0:
                continue
            # Split the transcript's read pairs between pre-mRNA (intronic
            # fraction) and mature mRNA; both counts are truncated to integers.
            intronic_reads = int(intronic_fraction * reads)
            exonic_reads = int((1 - intronic_fraction) * reads)
            (gene_id, gene_name, transcript_id, transcript_name) = transcript
            columns = [gene_name, gene_id, transcript_name, transcript_id,
                       str(rescaled_fpkm), str(length), str(rescaled_fpm),
                       str(intronic_reads), str(exonic_reads)]
            rescaled_out.write('\t'.join(columns) + '\n')

            fasta_stem = gene_name + ':' + gene_id + ':' + transcript_name + ':' + transcript_id
            if exonic_reads >= 1:
                _simulate_and_append(gem_reads, mrna_fasta_dir + '/' + fasta_stem + '.fa',
                                     exonic_reads, read_length, fragment_length,
                                     frag_length_stdev, outprefix, fastq1, fastq2)
            if intronic_reads >= 1:
                _simulate_and_append(gem_reads, premrna_fasta_dir + '/' + fasta_stem + ':' + 'pre_mRNA.fa',
                                     intronic_reads, read_length, fragment_length,
                                     frag_length_stdev, outprefix, fastq1, fastq2)

# Run the simulation only when executed as a script, so that importing this
# module does not parse sys.argv or launch any work.
if __name__ == '__main__':
    run()