##################################
#                                #
# Last modified 2016/09/02       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
# import scipy.stats
# import numpy
import math
import random
from sets import Set
import time
from multiprocessing import Pool
from threading import Thread

def _count_successes(trials, probability):
    """Return how many of ``trials`` independent random draws succeed.

    A draw succeeds when ``random.random()`` is strictly below
    ``probability``.  Because ``random.random()`` lies in [0.0, 1.0), the
    strict comparison means a probability of 0 never succeeds and a
    probability of 1 (or more) always succeeds.
    """
    successes = 0
    remaining = trials
    # A counting loop (rather than range()) avoids building a huge list on
    # Python 2 when a gene has millions of transcripts.
    while remaining > 0:
        remaining -= 1
        if random.random() < probability:
            successes += 1
    return successes


def simulate_converted_transcripts(args):
    """Simulate, per gene, how many transcripts survive two capture steps.

    The single tuple argument lets the function be passed directly to
    ``Pool.map``.

    Args:
        args: a 7-tuple ``(FPKMGeneDict, TranscriptDict, CR1, CR2,
            TotalFPKM, Ncells, NtranscriptsPerCell)`` where
            ``FPKMGeneDict`` maps gene -> FPKM, ``TranscriptDict`` is the
            dict to fill in, ``CR1``/``CR2`` are the two catch rates, and
            ``TotalFPKM`` is the normalising FPKM sum.

    Returns:
        The same ``TranscriptDict`` object, with one entry per gene of
        ``FPKMGeneDict`` holding the integer counts ``'original'``
        (transcripts present), ``'visible'`` (kept with probability
        ``CR1`` each) and ``'seq'`` (visible ones kept with probability
        ``CR2`` each).
    """
    (FPKMGeneDict, TranscriptDict, CR1, CR2,
     TotalFPKM, Ncells, NtranscriptsPerCell) = args

    for gene in FPKMGeneDict:
        if TotalFPKM:
            # float() keeps the ratio from truncating to 0 under Python 2
            # integer division when both values happen to be ints.
            expected = (float(FPKMGeneDict[gene]) / TotalFPKM) * Ncells * NtranscriptsPerCell
        else:
            # A zero FPKM total leaves no expression to distribute.
            expected = 0.0

        # Stochastic rounding: round up with probability equal to the
        # fractional part, so the expected count equals `expected`.
        original = int(expected)
        if random.random() < expected - original:
            original += 1

        visible = _count_successes(original, CR1)
        sequenceable = _count_successes(visible, CR2)

        TranscriptDict[gene] = {
            'original': original,
            'visible': visible,
            'seq': sequenceable,
        }

    return TranscriptDict

def simulate_read_counts(args):
    """Subsample each gene's sequenceable transcripts into reads.

    The single tuple argument lets the function be passed directly to
    ``Pool.map``.

    Args:
        args: a 2-tuple ``(TranscriptDict, SamplingP)`` where
            ``TranscriptDict`` maps gene -> dict containing a ``'seq'``
            count and ``SamplingP`` is the per-transcript probability of
            producing a read.

    Returns:
        The same ``TranscriptDict`` object, with a ``'reads'`` entry added
        to every gene's dict.  The value is a float count.
    """
    TranscriptDict, SamplingP = args

    for counts in TranscriptDict.values():
        # Accumulated as a float so callers see e.g. 3.0 rather than 3.
        reads = 0.0
        remaining = counts['seq']
        while remaining > 0:
            remaining -= 1
            # random.random() is in [0.0, 1.0); the strict comparison makes
            # a probability of 0 never yield a read and >= 1 always yield one.
            if random.random() < SamplingP:
                reads += 1
        counts['reads'] = reads

    return TranscriptDict

def _split_into_chunks(items, n_chunks):
    """Split ``items`` into ``n_chunks`` consecutive slices.

    The first ``n_chunks - 1`` slices each hold ``len(items) // n_chunks``
    items; the last slice takes whatever remains.
    """
    size = len(items) // n_chunks
    chunks = [items[i * size:(i + 1) * size] for i in range(n_chunks - 1)]
    chunks.append(items[(n_chunks - 1) * size:])
    return chunks


def _tpm(count, total):
    """Return ``count`` as transcripts per million of ``total`` (0.0 if ``total`` is 0)."""
    if not total:
        return 0.0
    return (count / float(total)) * 1000000


def run():
    """Command-line entry point: simulate sampling of an FPKM table.

    Reads its settings from ``sys.argv`` (see the usage text), then:

    1. loads gene -> FPKM from the tab-separated table;
    2. simulates, in ``NP`` worker processes, the original / visible /
       sequenceable transcript counts of every gene;
    3. subsamples the sequenceable transcripts so that about ``RN`` reads
       are expected in total;
    4. writes one line per gene (sorted by name) to the output file with
       the original and sampled TPM and whether the sampled TPM is within
       the allowed relative deviation ``AD`` of the original.

    Exits with status 1 when the command line is incomplete or invalid.
    """
    # The output file name is read from sys.argv[10], so eleven entries
    # (script name + ten arguments) are required.
    if len(sys.argv) < 11:
        print('usage: python %s FPKM_table geneNameFieldID geneFPKMFieldID Number_cells Number_transcripts_per_cell Catch_rate_1 Catch_rate_2 Read_Number allowed_deviation outfile [-p threads]' % sys.argv[0])
        print('\tCatch Rate #1: library building catch rate (ligation of tails, etc.)')
        print('\tCatch Rate #2: sequencing rate (pore loading at infinite sequencing depth, etc.)')
        print('\tAllowed deviation: the permissible deviation from the true quantification (in TPMs); should be a float between 0 and 1')
        sys.exit(1)

    FPKMtable = sys.argv[1]
    geneNameFieldID = int(sys.argv[2])
    FPKMFieldID = int(sys.argv[3])
    Ncells = int(sys.argv[4])
    NtranscriptsPerCell = int(sys.argv[5])
    CR1 = float(sys.argv[6])
    CR2 = float(sys.argv[7])
    RN = int(sys.argv[8])
    AD = float(sys.argv[9])
    outfilename = sys.argv[10]

    NP = 1
    if '-p' in sys.argv:
        p_index = sys.argv.index('-p') + 1
        if p_index >= len(sys.argv):
            print('-p requires a thread count')
            sys.exit(1)
        NP = int(sys.argv[p_index])
    if NP < 1:
        print('the number of threads given with -p must be at least 1')
        sys.exit(1)

    FPKMGeneDict = {}
    TotalFPKM = 0.0

    with open(FPKMtable) as linelist:
        for line in linelist:
            if line.startswith(('#', 'tracking_id')):
                continue
            fields = line.strip().split('\t')
            gene = fields[geneNameFieldID]
            FPKM = float(fields[FPKMFieldID])
            # NOTE(review): a gene name occurring on several rows keeps only
            # its last FPKM here, while every row still adds to TotalFPKM.
            FPKMGeneDict[gene] = FPKM
            TotalFPKM += FPKM

    print('finished importing FPKM table')

    # Shuffle before chunking so each worker gets a random mix of genes.
    geneNamesList = list(FPKMGeneDict.keys())
    random.shuffle(geneNamesList)
    gene_chunks = _split_into_chunks(geneNamesList, NP)

    conversion_tasks = [
        ({gene: FPKMGeneDict[gene] for gene in chunk}, {},
         CR1, CR2, TotalFPKM, Ncells, NtranscriptsPerCell)
        for chunk in gene_chunks
    ]

    TranscriptDict = {}
    TotalOriginalTranscripts = 0.0
    TotalVisibleTranscripts = 0.0
    TotalSequencableTranscripts = 0.0
    SequencedTranscripts = 0.0

    # One pool serves both simulation stages and is always shut down.
    pool = Pool(NP)
    try:
        VisibleTranscriptDicts = pool.map(simulate_converted_transcripts, conversion_tasks)

        print('finished simulating converted transcripts')
        print('finished simulating sequenceable transcripts')

        for VTD in VisibleTranscriptDicts:
            for gene, counts in VTD.items():
                TranscriptDict[gene] = {
                    'original': counts['original'],
                    'visible': counts['visible'],
                    'seq': counts['seq'],
                }
                TotalOriginalTranscripts += counts['original']
                TotalVisibleTranscripts += counts['visible']
                TotalSequencableTranscripts += counts['seq']

        print('TotalOriginalTranscripts %s' % TotalOriginalTranscripts)
        print('TotalVisibleTranscripts %s' % TotalVisibleTranscripts)
        print('TotalSequencableTranscripts: %s' % TotalSequencableTranscripts)

        # Per-transcript probability of yielding a read, chosen so that RN
        # reads are expected overall; nothing to sample from means 0.
        if TotalSequencableTranscripts > 0:
            SamplingP = RN / TotalSequencableTranscripts
        else:
            SamplingP = 0.0
        print('SamplingP %s' % SamplingP)

        sampling_tasks = [
            ({gene: TranscriptDict[gene] for gene in chunk}, SamplingP)
            for chunk in gene_chunks
        ]
        SequencedTranscriptDicts = pool.map(simulate_read_counts, sampling_tasks)
    finally:
        pool.close()
        pool.join()

    print('finished simulating read counts')

    # Workers operate on pickled copies, so copy their results back.
    for VTD in SequencedTranscriptDicts:
        for gene, counts in VTD.items():
            TranscriptDict[gene]['reads'] = counts['reads']
            SequencedTranscripts += counts['reads']

    print('SequencedTranscripts: %s' % SequencedTranscripts)

    with open(outfilename, 'w') as outfile:
        outfile.write('#gene\toriginal_FPKM\toriginal_TPM\tSampled_Reads\tSampled_TPM\tWithin_Range?' + '\n')

        for gene in sorted(FPKMGeneDict):
            counts = TranscriptDict[gene]
            TPMoriginal = _tpm(counts['original'], TotalOriginalTranscripts)
            TPMsampled = _tpm(counts['reads'], SequencedTranscripts)

            if TPMoriginal == 0.0:
                # No original transcripts: a relative deviation is undefined.
                WithinRange = 'n/a'
            elif math.fabs(TPMsampled - TPMoriginal) / TPMoriginal <= AD:
                WithinRange = 1
            else:
                WithinRange = 0

            columns = [gene, FPKMGeneDict[gene], TPMoriginal,
                       counts['reads'], TPMsampled, WithinRange]
            outfile.write('\t'.join(str(column) for column in columns) + '\n')

# Only start the simulation when executed as a script.  multiprocessing may
# import the main module inside worker processes (e.g. where processes are
# spawned rather than forked); an unguarded call would re-run it there.
if __name__ == '__main__':
    run()

#    threads = [None] * NP
#    i=0
#    for (FDict,TDict,CR1,TotalFPKM,Ncells,NtranscriptsPerCell) in TDictArray:
#        print i, len(FDict.keys()), len(TDict.keys())
#        threads[i] = Thread(target = simulate_converted_transcripts, args = (FDict,TDict,CR1,TotalFPKM,Ncells,NtranscriptsPerCell))
#        threads[i].start()
#        i+=1

#    for i in range(len(threads)):
#        threads[i].join()
