##################################
#                                #
# Last modified 04/09/2013       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import scipy.stats
import numpy
import math
import random
from sets import Set
import time

def _read_fpkm_table(FPKMtable, geneFieldID, FPKMFieldID, FPKMCutoff):
    """Load the genes that pass the FPKM cutoff from a tab-separated table.

    Lines starting with '#' and blank lines are skipped. Rows whose FPKM is
    below FPKMCutoff, or is NaN/infinite, are dropped. When a gene occurs on
    several rows the last row wins.

    Returns a dict mapping gene -> {'originalFPKM': float}.
    """
    GeneDict = {}
    with open(FPKMtable) as linelist:
        for line in linelist:
            if line.startswith('#'):
                continue
            # Strip only trailing whitespace: a full strip() would also eat a
            # leading empty field and shift every column index by one.
            line = line.rstrip()
            if not line:
                continue
            fields = line.split('\t')
            gene = fields[geneFieldID]
            FPKM = float(fields[FPKMFieldID])
            # NaN/inf cannot be placed in any [lower, upper) percentile group.
            if math.isnan(FPKM) or math.isinf(FPKM) or FPKM < FPKMCutoff:
                continue
            GeneDict[gene] = {'originalFPKM': FPKM}
    return GeneDict


def _percentile_boundaries(sortedFPKMs):
    """Return {group: FPKM bound} for the groups 0, 0.1, ..., 0.9, 1.

    sortedFPKMs must be non-empty and sorted ascending. Group k/10
    (k = 1..9) maps to the FPKM at the k-th decile, 0 maps to the minimum and
    1 to max + 1. A gene whose FPKM lies in [PG[g_k], PG[g_k+1]) is later
    assigned group g_k+1, so every finite FPKM in the list is covered.
    """
    n = len(sortedFPKMs)
    PG = {}
    for i in range(1, 11):
        # i*n//10 rather than i*(n//10): the latter drops the n % 10
        # remainder and collapses every boundary onto the max when n < 10.
        PG[i / 10.] = sortedFPKMs[max(i * n // 10 - 1, 0)]
    PG[0] = sortedFPKMs[0]
    PG[1] = sortedFPKMs[-1] + 1
    return PG


def _percentile_group(FPKM, PG, PGGroups):
    """Return the group g_k+1 such that PG[g_k] <= FPKM < PG[g_k+1].

    PGGroups is the sorted list of PG's keys.
    """
    for lower, upper in zip(PGGroups[:-1], PGGroups[1:]):
        if PG[lower] <= FPKM < PG[upper]:
            return upper
    raise ValueError('FPKM %r lies outside [%r, %r)'
                     % (FPKM, PG[PGGroups[0]], PG[PGGroups[-1]]))


def _sample_PE(mu, minFractionCells):
    """Draw P_e from N(mu, |0.9 - mu|) restricted to [minFractionCells, 1].

    Uses rejection sampling. When sigma is 0 (mu == 0.9) the Gaussian
    always yields mu, so mu is returned if it is in range and a ValueError is
    raised otherwise (rejection sampling would never terminate).
    """
    sigma = abs(0.9 - mu)
    if sigma == 0:
        if minFractionCells <= mu <= 1:
            return mu
        raise ValueError('minFractionOfCellsAGeneIsExpressedIn (%s) exceeds '
                         'the fixed P_e of %s for this FPKM group'
                         % (minFractionCells, mu))
    while True:
        p = random.gauss(mu, sigma)
        if minFractionCells <= p <= 1:
            return p


def _simulate_cells(GeneDict, T, Psmc, N):
    """Simulate N cells, accumulating per-gene counters in GeneDict in place.

    In each cell every gene is switched on independently with probability
    GeneDict[gene]['PE']. The cell's T transcripts are split among the
    expressed genes in proportion to FPKM (truncated to int) -> 'CCG'. Each
    copy is caught into the library with probability Psmc -> 'CLG'.
    'CellsExpressedIn' counts the cells in which the gene was switched on.

    Returns (TotalCCG, TotalCLG) as floats.
    """
    TotalCCG = 0.0
    TotalCLG = 0.0
    genes = list(GeneDict)
    for _ in range(N):
        expressed = []
        TotalFPKMNi = 0.0
        for gene in genes:
            if random.random() <= GeneDict[gene]['PE']:
                expressed.append(gene)
                TotalFPKMNi += GeneDict[gene]['originalFPKM']
                GeneDict[gene]['CellsExpressedIn'] += 1
        if TotalFPKMNi == 0:
            # No expressed gene carries any FPKM: this cell contributes no
            # copies (dividing by the total would raise ZeroDivisionError).
            continue
        for gene in expressed:
            CCG = int(GeneDict[gene]['originalFPKM'] / TotalFPKMNi * T)
            TotalCCG += CCG
            GeneDict[gene]['CCG'] += CCG
            caught = sum(1 for _ in range(CCG) if random.random() <= Psmc)
            GeneDict[gene]['CLG'] += caught
            TotalCLG += caught
    return TotalCCG, TotalCLG


def _write_results(outfilename, GeneDict, TotalFPKM, TotalCCG, TotalCLG):
    """Write one tab-separated row per gene, sorted by gene name."""
    def rescale(count, total):
        # Convert a count to FPKM units via its share of the total; report
        # 0.0 when nothing was counted at all rather than dividing by zero.
        if total == 0:
            return 0.0
        return (count / total) * TotalFPKM

    header = '#Gene\tFPKM\tP_e\tCellsExpressedIn\tActual_Copies\tCopies_In_Library\tActual_Copies_FPKM\tLibrary_FPKM'
    with open(outfilename, 'w') as outfile:
        outfile.write(header + '\n')
        for gene in sorted(GeneDict):
            rec = GeneDict[gene]
            row = [gene, rec['originalFPKM'], rec['PE'], rec['CellsExpressedIn'],
                   rec['CCG'], rec['CLG'],
                   rescale(rec['CCG'], TotalCCG), rescale(rec['CLG'], TotalCLG)]
            outfile.write('\t'.join(str(x) for x in row) + '\n')


def run():
    """Command-line entry point: simulate single-cell sampling of an FPKM table.

    Positional arguments (sys.argv[1..9]):
      FPKM_table, geneFieldID, FPKMFieldID (0-based column indices),
      minFPKMCutoff, transcripts_per_cell (T), catch_rate (Psmc),
      number_cells (N), minFractionOfCellsAGeneIsExpressedIn (in [0, 1)),
      outfile.

    Genes passing the cutoff are split into FPKM decile groups. Each gene
    gets an expression probability P_e drawn around its group value. N cells
    are simulated, and per-gene true and library copy counts are written to
    outfile. Exits with status 1 on bad arguments or when no gene passes the
    cutoff.
    """
    if len(sys.argv) < 10:
        print('usage: python %s FPKM_table geneFieldID FPKMFieldID minFPKMCutoff transcripts_per_cell catch_rate number_cells minFractionOfCellsAGeneIsExpressedIn outfile' % sys.argv[0])
        sys.exit(1)

    FPKMtable = sys.argv[1]
    geneFieldID = int(sys.argv[2])
    FPKMFieldID = int(sys.argv[3])
    FPKMCutoff = float(sys.argv[4])
    T = int(sys.argv[5])
    Psmc = float(sys.argv[6])
    N = int(sys.argv[7])
    minFractionCells = float(sys.argv[8])
    outfilename = sys.argv[9]

    # With a lower bound >= 1 the accepted interval for P_e has zero width
    # and the rejection sampler would spin forever.
    if not 0 <= minFractionCells < 1:
        print('minFractionOfCellsAGeneIsExpressedIn must be >= 0 and < 1')
        sys.exit(1)

    GeneDict = _read_fpkm_table(FPKMtable, geneFieldID, FPKMFieldID, FPKMCutoff)
    if not GeneDict:
        print('no genes with FPKM >= %s found in %s' % (FPKMCutoff, FPKMtable))
        sys.exit(1)

    # Derived from GeneDict so duplicate gene rows cannot skew the bins/totals.
    FPKMList = sorted(rec['originalFPKM'] for rec in GeneDict.values())
    print('%d genes found' % len(FPKMList))

    PG = _percentile_boundaries(FPKMList)
    print(PG)
    PGGroups = sorted(PG)

    for rec in GeneDict.values():
        rec['CLG'] = 0
        rec['CCG'] = 0
        rec['CellsExpressedIn'] = 0
        rec['PG'] = _percentile_group(rec['originalFPKM'], PG, PGGroups)
        rec['PE'] = _sample_PE(rec['PG'], minFractionCells)

    TotalFPKM = sum(FPKMList)
    TotalCCG, TotalCLG = _simulate_cells(GeneDict, T, Psmc, N)
    _write_results(outfilename, GeneDict, TotalFPKM, TotalCCG, TotalCLG)

# Only run the simulation when executed as a script, not when imported.
if __name__ == '__main__':
    run()