##################################
#                                #
# Last modified 2016/08/28       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
# import scipy.stats
# import numpy
import math
import random
from sets import Set
import time

# Print a progress line each time this many transcripts have been processed.
_PROGRESS_INTERVAL = 10000000


def _safe_ratio(numerator, denominator):
    """Return numerator / denominator as a float, or 0.0 if denominator is 0."""
    if denominator == 0:
        return 0.0
    return numerator / float(denominator)


def _read_fpkm_table(path, name_field, fpkm_field):
    """Load per-gene FPKM values from a tab-delimited table.

    Lines starting with '#' or 'tracking_id' are skipped, as are blank lines.

    Returns a tuple (fpkm_by_gene, total_fpkm) where total_fpkm is the sum of
    the FPKM column over every parsed line.
    """
    fpkm_by_gene = {}
    total_fpkm = 0.0
    with open(path) as table:
        for line in table:
            if line.startswith(('#', 'tracking_id')):
                continue
            stripped = line.strip()
            if not stripped:
                # A blank (e.g. trailing) line carries no record.
                continue
            fields = stripped.split('\t')
            gene = fields[name_field]
            fpkm = float(fields[fpkm_field])
            # NOTE(review): if a gene name occurs on several lines, the last
            # value wins here while every value is still added to the total.
            fpkm_by_gene[gene] = fpkm
            total_fpkm += fpkm
    return fpkm_by_gene, total_fpkm


def _thin(trials, keep_probability, processed, total, label):
    """Keep each of `trials` items independently with `keep_probability`.

    `processed` is the running count of items handled so far across calls;
    a progress line of the form '<processed> / <total> <label> processed' is
    printed every _PROGRESS_INTERVAL items.

    Returns a tuple (kept, processed) with the number of items kept and the
    updated running count.
    """
    kept = 0
    done = 0
    while done < trials:
        done += 1
        processed += 1
        if processed % _PROGRESS_INTERVAL == 0:
            print('%s / %s %s processed' % (processed, total, label))
        # random.random() lies in [0, 1), so a strict '<' succeeds with
        # probability exactly keep_probability (never for 0, always for >= 1).
        if random.random() < keep_probability:
            kept += 1
    return kept, processed


def run():
    """Simulate library capture and sequencing of transcripts from an FPKM table.

    All inputs are read from sys.argv (see the usage message below). For each
    gene the number of original transcripts is derived from its share of the
    total FPKM, stochastically rounded to an integer. Every transcript is then
    passed through three independent random filters in turn: catch rate #1,
    catch rate #2, and a sampling probability of
    Read_Number / (number of transcripts surviving catch rate #2).

    One line per gene (sorted by gene name) is written to the output file with
    the original FPKM, the original TPM, the sampled read count, the sampled
    TPM, and whether the sampled TPM is within the allowed relative deviation
    of the original TPM (1/0, or 'n/a' when the original TPM is zero).

    Exits with status 1 after printing usage if too few arguments are given.
    """
    # Ten arguments follow the script name, so sys.argv needs 11 entries.
    if len(sys.argv) < 11:
        print('usage: python %s FPKM_table geneNameFieldID geneFPKMFieldID Number_cells Number_transcripts_per_cell Catch_rate_1 Catch_rate_2 Read_Number allowed_deviation outfile' % sys.argv[0])
        print('\tCatch Rate #1: library building catch rate (ligation of tails, etc.)')
        print('\tCatch Rate #2: sequencing rate (pore loading at infinite sequencing depth, etc.)')
        print('\tAllowed deviation: the permissible deviation from the true quantification (in TPMs); should be a float between 0 and 1')
        sys.exit(1)

    fpkm_table = sys.argv[1]
    gene_name_field = int(sys.argv[2])
    fpkm_field = int(sys.argv[3])
    n_cells = int(sys.argv[4])
    n_transcripts_per_cell = int(sys.argv[5])
    catch_rate_1 = float(sys.argv[6])
    catch_rate_2 = float(sys.argv[7])
    read_number = int(sys.argv[8])
    allowed_deviation = float(sys.argv[9])
    outfilename = sys.argv[10]

    fpkm_by_gene, total_fpkm = _read_fpkm_table(fpkm_table, gene_name_field, fpkm_field)

    print('finished importing FPKM table')

    expected_total = n_transcripts_per_cell * n_cells
    transcripts = {}

    # Stage 1: assign integer transcript counts and apply catch rate #1.
    total_original = 0
    total_visible = 0
    processed = 0
    for gene in fpkm_by_gene:
        expected = _safe_ratio(fpkm_by_gene[gene], total_fpkm) * n_cells * n_transcripts_per_cell
        original = int(expected)
        # Stochastic rounding: round up with probability equal to the
        # fractional part, so the expected count is preserved.
        if random.random() < expected - original:
            original += 1
        visible, processed = _thin(original, catch_rate_1, processed,
                                   expected_total, 'original transcripts')
        transcripts[gene] = {'original': original, 'visible': visible}
        total_original += original
        total_visible += visible

    print('finished simulating converted transcripts')
    print('TotalOriginalTranscripts %s' % float(total_original))
    print('TotalVisibleTranscripts %s' % float(total_visible))

    # Stage 2: apply catch rate #2 to the transcripts that survived stage 1.
    total_sequenceable = 0
    processed = 0
    for gene in fpkm_by_gene:
        sequenceable, processed = _thin(transcripts[gene]['visible'], catch_rate_2,
                                        processed, total_visible,
                                        'converted transcripts')
        transcripts[gene]['seq'] = sequenceable
        total_sequenceable += sequenceable

    print('finished simulating sequenceable transcripts')
    print('TotalSequencableTranscripts: %s' % float(total_sequenceable))

    # With nothing left to sequence the probability is irrelevant; use 0.0
    # instead of dividing by zero.
    sampling_p = _safe_ratio(read_number, total_sequenceable)
    print('SamplingP %s' % sampling_p)

    # Stage 3: sample reads from the sequenceable transcripts.
    total_reads = 0
    processed = 0
    for gene in fpkm_by_gene:
        reads, processed = _thin(transcripts[gene]['seq'], sampling_p,
                                 processed, total_sequenceable,
                                 'sequenceable transcripts')
        transcripts[gene]['reads'] = reads
        total_reads += reads

    print('finished simulating read counts')

    with open(outfilename, 'w') as outfile:
        outfile.write('#gene\toriginal_FPKM\toriginal_TPM\tSampled_Reads\tSampled_TPM\tWithin_Range?\n')

        for gene in sorted(fpkm_by_gene):
            counts = transcripts[gene]
            tpm_original = _safe_ratio(counts['original'], total_original) * 1000000
            tpm_sampled = _safe_ratio(counts['reads'], total_reads) * 1000000
            if tpm_original == 0.0:
                # No original transcripts means no reads either, and a
                # relative deviation is undefined.
                within_range = 'n/a'
            elif math.fabs(tpm_sampled - tpm_original) / tpm_original <= allowed_deviation:
                within_range = 1
            else:
                within_range = 0
            columns = [
                gene,
                str(fpkm_by_gene[gene]),
                str(tpm_original),
                # Read counts are written as floats (e.g. '3.0') to keep the
                # established output format.
                str(float(counts['reads'])),
                str(tpm_sampled),
                str(within_range),
            ]
            outfile.write('\t'.join(columns) + '\n')

# Only start the simulation when this file is executed as a script, so that
# importing it does not read sys.argv or exit the interpreter.
if __name__ == '__main__':
    run()