##################################
#                                #
# Last modified 01/21/2013       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import scipy.stats
from sets import Set
import time

def _emit(*values):
    """Write the values to stdout on one line, separated by single spaces.

    A single-argument print call is used so the output is the same whether
    the interpreter treats print as a statement or as a function.
    """
    print(' '.join([str(value) for value in values]))


def _two_sided_binomial_pvalue(successes, trials):
    """Return the two-sided binomial test p-value against p = 0.5.

    The legacy scipy.stats.binom_test is used whenever it exists, so results
    are unchanged on the SciPy versions this script was written for.  Newer
    SciPy releases only provide scipy.stats.binomtest, which is used as a
    fallback.
    """
    legacy_test = getattr(scipy.stats, 'binom_test', None)
    if legacy_test is not None:
        return legacy_test(successes, trials, 0.5)
    if trials == 0:
        # binomtest rejects n == 0; the legacy function reported 1.0 here
        # because the observed count equals the expected count (0 == 0.5*0).
        return 1.0
    return scipy.stats.binomtest(successes, trials, 0.5).pvalue


def _catch_rate_pvalue(copies, copies_limit, minor_copies, major_copies, catch_rate):
    """Estimate the catch-rate p-value for one (minor, major) copy-count pair.

    For every candidate total CCa in range(int(copies), copies_limit):
      * its weight is the nbinom(copies, catch_rate) pmf at CCa - int(copies),
        divided by the sum of those pmf values over the whole range;
      * a binom(int(0.5 * CCa), catch_rate) distribution is evaluated, and the
        weight is multiplied by [sum of pmf(i) for i in 0..minor_copies] and by
        [sum of pmf(j) for j in major_copies..int(0.5 * CCa) - 1].
    The weighted products are summed.  The result stays the integer 0 when no
    term contributes, which is what then gets written to the output file.
    """
    first_total = int(copies)
    # The negative binomial does not depend on CCa, so freeze it once and
    # evaluate each pmf value once (it is needed for both the normalisation
    # and the weights).
    nb_dist = scipy.stats.nbinom(copies, catch_rate)
    nb_pmf = [nb_dist.pmf(total - first_total)
              for total in range(first_total, copies_limit)]
    nb_norm = sum(nb_pmf)
    _emit(first_total, copies_limit, 'TotalNB', nb_norm)

    pvalue = 0
    for offset, nb_value in enumerate(nb_pmf):
        if nb_value == 0:
            continue
        weight = nb_value / nb_norm
        per_allele = int(0.5 * (first_total + offset))
        allele_dist = scipy.stats.binom(per_allele, catch_rate)
        # The sum over j is independent of i, so the double sum over (i, j) of
        # pmf(i) * pmf(j) factors into the product of two single sums.
        upper = sum([allele_dist.pmf(j)
                     for j in range(major_copies, max(major_copies, per_allele))])
        lower = sum([allele_dist.pmf(i) for i in range(minor_copies + 1)])
        if lower == 0 or upper == 0:
            continue
        pvalue += weight * lower * upper
    return pvalue


def run():
    """Annotate an ASE table with copies-per-cell statistics.

    Command-line arguments (sys.argv[1:]), in order:
      ASE_output               tab-separated ASE file; lines starting with '#'
                               are treated as headers, column 0 is the gene
                               name, columns 5/6/7 are the maternal fraction,
                               paternal fraction and read p-value
      copies_per_cell_file     tab-separated file of copies per cell; lines
                               starting with '#' or 'tracking_id' are ignored
      geneFieldID              0-based column of the gene name in that file
      copies_per_cell_fieldID  0-based column of the copy number in that file
      catch_rate               float
      evaluation_limit_factor  float
      copies_limit_cap         int
      p_value_cutoff           float
      outfile                  path of the output file

    Each ASE row is written to outfile with five extra columns: copies,
    mat_copies, pat_copies, copies_p-value and catch_rate_estimation_p-value.
    The last column is 'not_calculated' when either the copies p-value or the
    read p-value is at or above p_value_cutoff.  Genes missing from the
    copies file are reported on stdout and left out of the output.

    Prints usage and exits with status 1 when too few arguments are given.
    """
    # Nine positional arguments are read below, so argv needs ten entries.
    if len(sys.argv) < 10:
        print('usage: python %s ASE_output copies_per_cell_file geneFieldID copies_per_cell_fieldID catch_rate evaluation_limit_factor copies_limit_cap p_value_cutoff outfile' % sys.argv[0])
        print('\tassumed ASE format:')
        print('\t#geneName\tgeneID\tchr\tmat_collapsed_reads\tpat_collapsed_reads\tmat_fraction\tpat_fraction\tp-value')
        print('\tcatch rate paramter should be a float, for example 0.10')
        print('\tNote: multiple entires for the same genee in the copies_per_cell file will be summed')
        print('\tevaluation_limit_factor referse to the number up to which catch rate p-value estimation should be carried out, which will be the evaluation_limit_factor times the number of copies per cell')
        print('\tfor large number of copies, the calculaiton becomesi impractical; the minimum of the copies_limit_cap or the evaluation_limit_factor times the number of copies per cell will be used')
        print('\tuse the p-value cutoff in order to save computational time; genes for which either the read or copies p-values are above that cutoff will not be evaluated for catch rate effects')
        sys.exit(1)

    ase_filename = sys.argv[1]
    copies_filename = sys.argv[2]
    geneID = int(sys.argv[3])
    copiesID = int(sys.argv[4])
    catch_rate = float(sys.argv[5])
    evaluation_factor = float(sys.argv[6])
    copies_limit_cap = int(sys.argv[7])
    p_value_cutoff = float(sys.argv[8])
    outfilename = sys.argv[9]

    # Sum the copies of every gene; repeated entries for a gene accumulate.
    copies_by_gene = {}
    with open(copies_filename) as copies_file:
        for line in copies_file:
            if line.startswith('#') or line.startswith('tracking_id'):
                continue
            fields = line.strip().split('\t')
            gene = fields[geneID]
            copies_by_gene[gene] = copies_by_gene.get(gene, 0) + float(fields[copiesID])

    # NOTE(review): cached results are keyed only on the (minor, major) copy
    # counts, although the computed value also depends on the gene's copies
    # and copies_limit.  Kept as in the original; confirm this approximation
    # is intended.
    cached_pvalues = {}
    genes_seen = 0

    with open(outfilename, 'w') as outfile:
        with open(ase_filename) as ase_file:
            for line in ase_file:
                if line.startswith('#'):
                    header = '\t'.join([line.strip(), 'copies', 'mat_copies', 'pat_copies',
                                        'copies_p-value', 'catch_rate_estimation_p-value'])
                    outfile.write(header + '\n')
                    continue
                fields = line.strip().split('\t')
                genes_seen += 1
                gene = fields[0]
                if gene not in copies_by_gene:
                    # Actually skip the row; otherwise the previous gene's
                    # copy number would silently be reused for this one.
                    _emit(gene, 'not found in expression file, skipping')
                    continue
                copies = copies_by_gene[gene]

                mat_fraction = float(fields[5])
                pat_fraction = float(fields[6])
                reads_pvalue = float(fields[7])
                mat_copies = int(mat_fraction * copies)
                pat_copies = int(pat_fraction * copies)
                copies_pvalue = _two_sided_binomial_pvalue(mat_copies, mat_copies + pat_copies)
                copies_limit = min(int(evaluation_factor * copies), copies_limit_cap)
                minor_copies = min(mat_copies, pat_copies)
                major_copies = max(mat_copies, pat_copies)

                start = time.time()
                _emit(gene, genes_seen, 'copies:', int(copies), int(copies_limit),
                      'mat copies:', mat_copies, 'pat copies:', pat_copies,
                      (copies_limit * copies) * minor_copies * copies_limit)

                if copies_pvalue >= p_value_cutoff or reads_pvalue >= p_value_cutoff:
                    catch_rate_pvalue = 'not_calculated'
                else:
                    cache_key = (minor_copies, major_copies)
                    if cache_key not in cached_pvalues:
                        cached_pvalues[cache_key] = _catch_rate_pvalue(
                            copies, copies_limit, minor_copies, major_copies, catch_rate)
                    catch_rate_pvalue = cached_pvalues[cache_key]

                outline = '\t'.join([line.strip(), str(copies), str(mat_copies), str(pat_copies),
                                     str(copies_pvalue), str(catch_rate_pvalue)])
                elapsed = time.time() - start
                _emit(elapsed, outline)
                outfile.write(outline + '\n')

# Only run the pipeline when executed as a script, so that importing this
# module does not parse sys.argv or touch any files.
if __name__ == '__main__':
    run()