##################################
#                                #
# Last modified 03/13/2014       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import os
import math
from sets import Set

def getReverseComplement(preliminarysequence):
    """Return the reverse complement of a DNA sequence.

    Upper- and lower-case A, C, G, T and N are supported and their case is
    preserved (e.g. 'aCgT' -> 'AcGt'). Any other character raises KeyError.
    """
    DNA = {'A':'T','T':'A','G':'C','C':'G','N':'N','a':'t','t':'a','g':'c','c':'g','n':'n'}
    # join() builds the result in one pass; the former `s = s + c` loop was
    # quadratic in the read length.
    return ''.join(DNA[base] for base in reversed(preliminarysequence))

def _write(message):
    """Write `message` and a newline to standard output.

    Used instead of the `print` statement so this code parses under both
    Python 2 and Python 3.
    """
    sys.stdout.write(message + '\n')


def _fail(message):
    """Write `message` to standard error and exit with status 1."""
    sys.stderr.write(message + '\n')
    sys.exit(1)


def _exon_length(start, stop):
    """Return the length of a GTF feature (GTF coordinates are 1-based and inclusive)."""
    return abs(stop - start) + 1


def _transcript_length(exons):
    """Sum the exon lengths of one transcript; exons are (chrom, start, stop, strand) tuples."""
    return sum(_exon_length(start, stop) for (_chrom, start, stop, _strand) in exons)


def _fpm(fpkm, length):
    """Convert an FPKM value into fragments per million for a transcript of `length` bp."""
    return fpkm * (length / 1000.)


def _gtf_attribute(attributes, key):
    """Return the value of `key "value";` in a GTF attribute column, or None if absent."""
    marker = key + ' "'
    if marker not in attributes:
        return None
    return attributes.split(marker)[1].split('";')[0]


def _parse_gtf(path):
    """Group the `exon` lines of a GTF file by transcript.

    Returns a dict keyed by (geneID, geneName, transcriptID, transcriptName);
    each value is {'exons': [(chrom, start, stop, strand), ...]}. A missing
    gene_name / transcript_name falls back to the gene_id / transcript_id.
    """
    transcripts = {}
    with open(path) as handle:
        for line in handle:
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            # Skip blank/truncated lines as well as non-exon features.
            if len(fields) < 9 or fields[2] != 'exon':
                continue
            attributes = fields[8]
            transcriptID = _gtf_attribute(attributes, 'transcript_id')
            geneID = _gtf_attribute(attributes, 'gene_id')
            if transcriptID is None or geneID is None:
                _fail('GTF exon line without gene_id/transcript_id, exiting: ' + line.strip())
            transcriptName = _gtf_attribute(attributes, 'transcript_name')
            if transcriptName is None:
                transcriptName = transcriptID
            geneName = _gtf_attribute(attributes, 'gene_name')
            if geneName is None:
                geneName = geneID
            key = (geneID, geneName, transcriptID, transcriptName)
            exon = (fields[0], int(fields[3]), int(fields[4]), fields[6])
            transcripts.setdefault(key, {'exons': []})['exons'].append(exon)
    return transcripts


def _filter_transcripts(transcripts, min_length=None):
    """Remove, in place, transcripts of length <= min_length or spanning several chromosomes."""
    # Iterate over a snapshot of the keys because entries are deleted in the loop.
    for key in list(transcripts):
        exons = transcripts[key]['exons']
        too_short = min_length is not None and _transcript_length(exons) <= min_length
        multi_chromosome = len(set(exon[0] for exon in exons)) > 1
        # A single delete per transcript: it may fail both tests.
        if too_short or multi_chromosome:
            del transcripts[key]


def _load_fpkm(path, transcripts):
    """Attach the FPKM from the expression table to each known transcript.

    Table columns 0, 1, 3, 4 form the transcript key and column 5 holds the
    FPKM. Sets transcripts[key]['FPKM'] for matching rows and returns
    (sum of FPKM over all rows, sum of FPM over matched transcripts).
    """
    total_fpkm = 0
    total_fpm = 0
    with open(path) as handle:
        for line in handle:
            if line.startswith('#') or line.startswith('tracking_id') or not line.strip():
                continue
            fields = line.strip().split('\t')
            key = (fields[0], fields[1], fields[3], fields[4])
            fpkm = float(fields[5])
            total_fpkm += fpkm
            if key not in transcripts:
                continue
            total_fpm += _fpm(fpkm, _transcript_length(transcripts[key]['exons']))
            transcripts[key]['FPKM'] = fpkm
    return total_fpkm, total_fpm


def _run_mason(mason, num_reads, read_length, fragment_length, frag_length_stdev, outprefix, fasta_path):
    """Run `mason illumina` on one fasta file and return the path of the SAM it writes.

    Exits if mason returns a non-zero status, so a stale SAM file from a
    previous call is never re-read.
    """
    import subprocess
    # Argument list instead of a shell string: file names come from the
    # annotation and must not be interpreted by a shell.
    cmd = [os.path.expanduser(mason), 'illumina',
           '--num-reads', str(num_reads),
           '--read-length', str(read_length),
           '--library-length-mean', str(fragment_length),
           '-le', str(frag_length_stdev),
           '--include-read-information', '--forward-only', '--simulate-qualities', '--mate-pairs',
           '--read-name-prefix', 'temp',
           '--prob-insert', '0', '--prob-delete', '0',
           '--haplotype-snp-rate', '0', '--haplotype-indel-rate', '0',
           '--output-file', outprefix + '.temp',
           fasta_path]
    status = subprocess.call(cmd)
    if status != 0:
        _fail('mason failed with exit status %d on %s, exiting' % (status, fasta_path))
    return outprefix + '.temp.sam'


def _read_sam_pairs(sam_path):
    """Group SAM records by read name into {read_id: {1: fields, 2: fields}}.

    Only FLAG 99 (first mate) and 147 (second mate) are accepted; anything
    else aborts the script.
    """
    pairs = {}
    with open(sam_path) as handle:
        for line in handle:
            if line.startswith('@'):
                continue
            fields = line.strip().split('\t')
            if fields[1] == '99':
                end = 1
            elif fields[1] == '147':
                end = 2
            else:
                _write('FLAG field other than 99 or 147 encountered, exiting')
                _write(str(fields))
                sys.exit(1)
            pairs.setdefault(fields[0], {})[end] = fields
    return pairs


def _write_fastq_pairs(pairs, fastq1, fastq2, read_number):
    """Append each read pair to the two FASTQ handles; return the updated read counter.

    Read 2 is reverse-complemented (and its qualities reversed) so it is
    reported in sequencing orientation. The read name records the reference,
    both positions and CIGARs, then ':' and a running read number.
    """
    for read_id in pairs:
        ends = pairs[read_id]
        if 1 not in ends or 2 not in ends:
            _fail('read %s is missing a mate in the mason SAM output, exiting' % read_id)
        fields1 = ends[1]
        fields2 = ends[2]
        read_number += 1
        # Reference name comes from this pair's own record, and the read number
        # is separated from CIGAR_read2 by ':' so the name stays parseable.
        name = ('@' + fields1[2] + ':pos_read1=' + fields1[3] + ':pos_read2=' + fields2[3]
                + ':CIGAR_read1=' + fields1[5] + ':CIGAR_read2=' + fields2[5]
                + ':' + str(read_number))
        fastq1.write(name + '_#0/1\n' + fields1[9] + '\n+\n' + fields1[10] + '\n')
        fastq2.write(name + '_#0/2\n' + getReverseComplement(fields2[9]) + '\n+\n'
                     + fields2[10][::-1] + '\n')
    return read_number


def run():
    """Simulate stranded paired-end RNA-seq reads with mason.

    Command-line arguments (read from sys.argv):
      1  transcript_FPKM_table          tab-separated; columns 0,1,3,4 are
                                        GeneID, GeneName, TranscriptID,
                                        TranscriptName and column 5 the FPKM
      2  NumberReadPairs                approximate total number of read pairs
      3  transcriptome-fasta-directory  one <geneName:geneID:transcriptName:transcriptID>.fa per transcript
      4  premRNA-fasta-directory        one <...>:pre_mRNA.fa per transcript
      5  IntronicFraction               fraction of each transcript's reads drawn from its pre-mRNA
      6  read_length                    passed to mason --read-length
      7  fragment_length                passed to mason --library-length-mean
      8  fragment_length_stdev          passed to mason -le
      9  path_to_mason
      10 GTF                            only `exon` lines are used
      11 outprefix
      optional: -minTranscriptLength bp  drop transcripts of length <= bp

    Writes <outprefix>.rescaled_FPKM, <outprefix>.read1.fastq and
    <outprefix>.read2.fastq; mason's own output goes to <outprefix>.temp*.
    """
    if len(sys.argv) < 12:
        _write('usage: python %s transcript_FPKM_table NumberReadPairs transcriptome-fasta-directory premRNA-fasta-directory IntronicFraction read_length fragment_length fragment_length_stdev path_to_mason GTF outprefix [-minTranscriptLength bp]' % sys.argv[0])
        _write('\t transcriptome-fasta-directory should contain one separate fasta file for each transcript; those should be labelled as follows: ZZZ3:ENSG00000036549.8:ZZZ3-002:ENST00000370798.1.fa')
        _write('\t premRNA-fasta-directory should contain one separate fasta file for each premRNA; those should be labelled as follows: ZZZ3:ENSG00000036549.8:ZZZ3-002:ENST00000370798.1:pre_mRNA.fa')
        _write('\t the transcript_FPKM_table files should contain the specified FPKM value (for transcripts only, the pre-mRNA reads will be calculated based on those)')
        _write('\t transcript_FPKM_table format: #GeneID\tGeneName\tGeneFPKM\tTranscriptID\tTranscriptName\tTranscriptFPKM\tFMI')
        _write('\t note: the script will generate stranded data with the first end of a pair being sense')
        sys.exit(1)

    FPKMtable = sys.argv[1]
    N = int(sys.argv[2])
    mRNAfastaDir = sys.argv[3]
    premRNAfastaDir = sys.argv[4]
    IF = float(sys.argv[5])
    read_length = int(sys.argv[6])
    fragment_length = int(sys.argv[7])
    frag_length_stdev = int(sys.argv[8])
    mason = sys.argv[9]
    GTF = sys.argv[10]
    outprefix = sys.argv[11]

    minTL = None
    if '-minTranscriptLength' in sys.argv:
        flag_index = sys.argv.index('-minTranscriptLength')
        if flag_index + 1 >= len(sys.argv):
            _fail('-minTranscriptLength requires a value, exiting')
        minTL = int(sys.argv[flag_index + 1])
        _write('will exclude transcripts shorter than ' + str(minTL))

    TranscriptDict = _parse_gtf(GTF)
    _write('finished parsing GTF file')

    _write('before ' + str(len(TranscriptDict)))
    _filter_transcripts(TranscriptDict, minTL)
    _write('after: ' + str(len(TranscriptDict)))

    TotalFPKM, TotalFPM = _load_fpkm(FPKMtable, TranscriptDict)
    _write('finished parsing simulated FPKMs ' + str(TotalFPKM) + ' ' + str(TotalFPM))
    if TotalFPM == 0:
        _fail('total FPM over annotated transcripts is 0; nothing to simulate, exiting')
    # Rescaling factor so the FPM of all simulated transcripts sums to one million.
    scale = TotalFPM / 1000000.

    with open(outprefix + '.rescaled_FPKM', 'w') as outfile_rescaled_FPKM, \
            open(outprefix + '.read1.fastq', 'w') as outfileFASTQ1, \
            open(outprefix + '.read2.fastq', 'w') as outfileFASTQ2:
        outfile_rescaled_FPKM.write('#geneName\tgeneID\ttranscriptName\ttranscriptID\tFPM\tlength\tFPKM\tIF_reads\tEF_reads\n')
        read_number = 0
        for transcript, record in TranscriptDict.items():
            if 'FPKM' not in record:
                continue
            FPKM = record['FPKM']
            length = _transcript_length(record['exons'])
            rescaled_FPM = _fpm(FPKM, length) / scale
            rescaled_FPKM = FPKM / scale
            reads = rescaled_FPM * (N / 1000000.)
            IFreads = int(IF * reads)
            EFreads = int((1 - IF) * reads)
            (geneID, geneName, transcriptID, transcriptName) = transcript
            # Column order follows the header: FPM, then length, then FPKM.
            row = [geneName, geneID, transcriptName, transcriptID, str(rescaled_FPM),
                   str(length), str(rescaled_FPKM), str(IFreads), str(EFreads)]
            outfile_rescaled_FPKM.write('\t'.join(row) + '\n')
            if reads == 0:
                continue
            base_name = ':'.join([geneName, geneID, transcriptName, transcriptID])
            # Exonic (mature mRNA) reads first, then intronic (pre-mRNA) reads.
            sources = ((EFreads, os.path.join(mRNAfastaDir, base_name + '.fa')),
                       (IFreads, os.path.join(premRNAfastaDir, base_name + ':pre_mRNA.fa')))
            for num_reads, fasta_path in sources:
                if num_reads < 1:
                    continue
                sam_path = _run_mason(mason, num_reads, read_length, fragment_length,
                                      frag_length_stdev, outprefix, fasta_path)
                read_number = _write_fastq_pairs(_read_sam_pairs(sam_path),
                                                 outfileFASTQ1, outfileFASTQ2, read_number)

# Run only when executed as a script, so importing this module (e.g. to reuse
# getReverseComplement) does not start a simulation.
if __name__ == '__main__':
    run()