##################################
#                                #
# Last modified 2017/04/18       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import gc
import math
import string
from sets import Set
import os
from multiprocessing import Pool
from threading import Thread
import numpy as np
from patsy import dmatrices, dmatrix, demo_data
from scipy.stats import nbinom
# from statsmodels import robust
from statsmodels.discrete.discrete_model import NegativeBinomial
import statsmodels.api as sm

def BayesianUpdateMu(alpha,alpha_mean,alpha_std):
    """Combine a single estimate with a group mean using 1/alpha_std**2 weights.

    Both ``alpha`` and ``alpha_mean`` are weighted by the same factor,
    1/alpha_std**2, and the weighted sum is divided by the total weight.
    Because the two weights are equal, the result is algebraically the
    midpoint of ``alpha`` and ``alpha_mean`` whenever ``alpha_std`` is a
    nonzero float.

    alpha      -- the individual estimate
    alpha_mean -- the mean it is combined with
    alpha_std  -- the spread used to build both weights
    """
    variance = alpha_std**2
    weighted_sum = alpha/variance + alpha_mean/variance
    total_weight = 1/variance + 1/variance
    return weighted_sum/total_weight

def _ll_nb2(y, X, beta, alph):
    """Return the negative binomial log-pmf of ``y`` for a log-linear mean.

    The mean is exp(X . beta). ``alph`` is converted to the scipy ``nbinom``
    parameters n = 1/alph and p = n/(n + mean), and ``nbinom.logpmf`` is
    evaluated at ``y`` with those parameters.
    """
    mean = np.exp(np.dot(X, beta))
    n_param = 1/alph
    return nbinom.logpmf(y, n_param, n_param/(n_param+mean))

def _nb_fit_counts(counts):
    """Return a copy of ``counts`` prepared for the negative binomial fit.

    When the largest value in ``counts`` is 0, the first entry of the copy is
    replaced by 1; otherwise the values are copied unchanged.  An empty input
    gives an empty list.
    """
    fit_counts = list(counts)
    # presumably the fit cannot be run on all-zero data, hence the
    # pseudo-count in the first position -- TODO confirm
    if fit_counts and max(fit_counts) == 0:
        fit_counts[0] = 1
    return fit_counts


def _fit_nb_mu_alpha(counts):
    """Fit NegativeBinomial to ``counts`` against a constant column of ones.

    Returns (mu, alpha, alpha_bse) where mu is exp(params[0]), alpha is
    params[1] and alpha_bse is bse[1], or the string 'nan' when reading the
    standard error raises.
    """
    endog = _nb_fit_counts(counts)
    exog = [1]*len(endog)
    res = NegativeBinomial(endog,exog).fit(disp=0)
    mu = np.exp(res.params[0])
    alpha = res.params[1]
    try:
        alpha_bse = res.bse[1]
    except Exception:
        # Keep the string sentinel: the rest of the script tests for it with
        # != 'nan'.  Only Exception is caught so that KeyboardInterrupt and
        # SystemExit still propagate.
        alpha_bse = 'nan'
    return mu, alpha, alpha_bse


def estimate_dispersion(args):
    """Estimate per-region mean and dispersion for each strand/sample group.

    ``args`` is a single tuple (NRCDict, DispersionDict), so the function can
    be handed directly to ``Pool.map``.

    NRCDict        -- maps region -> dict holding the count lists 'minusCN',
                      'minusSN', 'plusCN' and 'plusSN'
    DispersionDict -- dict that is filled in place and returned

    For every region, DispersionDict[region] receives the keys 'mu<X>',
    'alpha<X>' and 'alpha<X>_bse' for <X> in Cminus, Sminus, Cplus, Splus.
    A progress line is printed every 1000 regions.
    """
    NRCDict, DispersionDict = args

    # (suffix used in the output keys, key of the count list in NRCDict)
    groups = (('Cminus','minusCN'),
              ('Sminus','minusSN'),
              ('Cplus','plusCN'),
              ('Splus','plusSN'))

    for n_done, region in enumerate(NRCDict.keys(), 1):
        if n_done % 1000 == 0:
            print('%d regions processed' % n_done)
        region_params = {}
        for suffix, counts_key in groups:
            mu, alpha, alpha_bse = _fit_nb_mu_alpha(NRCDict[region][counts_key])
            region_params['mu' + suffix] = mu
            region_params['alpha' + suffix] = alpha
            region_params['alpha' + suffix + '_bse'] = alpha_bse
        DispersionDict[region] = region_params

    return DispersionDict

def _draw_nb_counts(mean, alpha, n_samples):
    """Draw ``n_samples`` negative binomial counts for a mean and dispersion.

    The draw uses np.random.negative_binomial with n = 1/alpha and success
    probability 1 - mean*alpha/(1 + mean*alpha).
    """
    r = 1./alpha
    p = mean*alpha/(1 + mean*alpha)
    return np.random.negative_binomial(r,1-p,size=n_samples)


def _fraction_true(mask, n_samples):
    """Return the number of True entries in ``mask`` divided by ``n_samples``."""
    return float(np.count_nonzero(mask))/n_samples


def sampling_p_values(args):
    """Compute sampling-based p-values for every region.

    ``args`` is a single tuple (NRCDict, PvalueDict, MFC, NSamplings), so the
    function can be handed directly to ``Pool.map``.

    NRCDict    -- maps region -> dict with the count lists 'plusSN', 'plusCN',
                  'minusSN', 'minusCN', the pooled means 'plusSandCMean' and
                  'minusSandCMean', and the dispersion values
                  'alpha{S,C}{plus,minus}_decile' and
                  'alphaSandC{plus,minus}_decile_{STARR,Control}'
    PvalueDict -- dict that is filled in place and returned
    MFC        -- accepted as part of the tuple but not used here
    NSamplings -- number of random draws per distribution

    PvalueDict[region] is set to
    (p_plus, p_minus, pm, PlusMinusRatio, PlusFC, MinusFC) where

    * PlusFC / MinusFC are (S mean + 1)/(C mean + 1) per strand, with each
      mean floored at 1;
    * p_plus / p_minus are the fractions of draws in which the simulated
      (S + 1)/(C + 1) ratio, with S and C both drawn around the pooled
      strand mean, exceeds the observed fold change;
    * pm is 1 when PlusFC equals MinusFC; otherwise it is the fraction of
      draws whose simulated plus/minus ratio is more extreme than
      PlusMinusRatio in the same direction.

    A progress line is printed every 100 regions.
    """
    NRCDict, PvalueDict, MFC, NSamplings = args

    for n_done, region in enumerate(NRCDict.keys(), 1):
        if n_done % 100 == 0:
            print('%d regions processed' % n_done)

        data = NRCDict[region]
        plusSNmean = max(np.mean(data['plusSN']),1.)
        plusCNmean = max(np.mean(data['plusCN']),1.)
        minusSNmean = max(np.mean(data['minusSN']),1.)
        minusCNmean = max(np.mean(data['minusCN']),1.)
        alphaSplus = data['alphaSplus_decile']
        alphaCplus = data['alphaCplus_decile']
        alphaSminus = data['alphaSminus_decile']
        alphaCminus = data['alphaCminus_decile']
        alphaPlusSTARR = data['alphaSandCplus_decile_STARR']
        alphaPlusControl = data['alphaSandCplus_decile_Control']
        plusSandCMean = max(data['plusSandCMean'],1.)
        alphaMinusSTARR = data['alphaSandCminus_decile_STARR']
        alphaMinusControl = data['alphaSandCminus_decile_Control']
        minusSandCMean = max(data['minusSandCMean'],1.)

        # The order of the eight draws is kept fixed so that a seeded
        # np.random stream yields the same samples for each distribution.
        # First four: S and C share the pooled mean of their strand.
        plusS = _draw_nb_counts(plusSandCMean, alphaPlusSTARR, NSamplings)
        plusC = _draw_nb_counts(plusSandCMean, alphaPlusControl, NSamplings)
        minusS = _draw_nb_counts(minusSandCMean, alphaMinusSTARR, NSamplings)
        minusC = _draw_nb_counts(minusSandCMean, alphaMinusControl, NSamplings)

        PlusFC = (plusSNmean + 1.)/(plusCNmean + 1.)
        MinusFC = (minusSNmean + 1.)/(minusCNmean + 1.)
        PlusMinusRatio = PlusFC/MinusFC
        meanPlusMinusFC = np.mean([PlusFC,MinusFC])

        # Last four: both strands use the same fold change (the average of
        # the two observed ones) applied to the strand's control mean.
        pmPlusS = _draw_nb_counts(plusCNmean*meanPlusMinusFC, alphaSplus, NSamplings)
        pmPlusC = _draw_nb_counts(plusCNmean, alphaCplus, NSamplings)
        pmMinusS = _draw_nb_counts(minusCNmean*meanPlusMinusFC, alphaSminus, NSamplings)
        pmMinusC = _draw_nb_counts(minusCNmean, alphaCminus, NSamplings)

        p_plus = _fraction_true((plusS + 1.)/(plusC + 1.) > PlusFC, NSamplings)
        p_minus = _fraction_true((minusS + 1.)/(minusC + 1.) > MinusFC, NSamplings)

        if PlusFC == MinusFC:
            pm = 1
        else:
            pmratio = ((pmPlusS + 1.)/(pmPlusC + 1.))/((pmMinusS + 1.)/(pmMinusC + 1.))
            if PlusMinusRatio > 1:
                pm = _fraction_true(pmratio > PlusMinusRatio, NSamplings)
            elif PlusMinusRatio < 1:
                pm = _fraction_true(pmratio < PlusMinusRatio, NSamplings)
            else:
                # Neither direction applies, so no draw is counted.
                pm = 0.0/NSamplings

        PvalueDict[region] = (p_plus,p_minus,pm,PlusMinusRatio,PlusFC,MinusFC)

    return PvalueDict

def run():

    if len(sys.argv) < 6:
        print 'usage: python %s table STARRfields Controlfields minFoldChange sampling_size outprefix [-norm median-of-ratios] [-p threads] [-dispBins number]' % sys.argv[0]
        print '\tNote!!!!: have the table include reads from the rest of the genome, not just the regions; otherwise read depth normalization will fail'
        print '\tNote: the script assumes the following format of regions:'
        print '\t\t\tchr10:102658058-102659058|+  for plus strand regions'
        print '\t\t\tchr10:102658058-102659058|-  for minus strand regions'
        print '\t\t\tcomplement.50K_chunks::chrY:9550001-9600001 for between-regions features (which will be ignored)'
        print '\tNote: STARRfields Controlfields should be comma separated'
        sys.exit(1)

    readCounts = sys.argv[1]
    STARRFieldIDs = []
    fields = sys.argv[2].split(',')
    for ID in fields:
        STARRFieldIDs.append(int(ID))
    ControlFieldIDs = []
    fields = sys.argv[3].split(',')
    for ID in fields:
        ControlFieldIDs.append(int(ID))
    MFC = float(sys.argv[4])
    NSamplings = int(sys.argv[5])
    outprefix = sys.argv[6]

    dispBins = 10
    if '-dispBins' in sys.argv:
        dispBins = int(sys.argv[sys.argv.index('-dispBins') + 1])

    doLDS = True
    if '-norm' in sys.argv:
        if sys.argv[sys.argv.index('-norm') + 1] == 'median-of-ratios':
            print 'will apply the median-of-ratios normalization'
            doLDS = False

    NP = 1
    if '-p' in sys.argv:
        NP = int(sys.argv[sys.argv.index('-p') + 1])
        print 'will run on', NP, 'threads'

    TotalReadCountDict = {}
    for ID in STARRFieldIDs:
        TotalReadCountDict[ID] = 0
    for ID in ControlFieldIDs:
        TotalReadCountDict[ID] = 0

    readCountsDict = {}

    linelist = open(readCounts)
    for line in linelist:
        if line.startswith('#'):
            continue
        fields = line.strip().split('\t')
        for ID in STARRFieldIDs:
            TotalReadCountDict[ID] += float(fields[ID])
        for ID in ControlFieldIDs:
            TotalReadCountDict[ID] += float(fields[ID])
        if fields[0].startswith('complement'):
            continue
        region = fields[0].split('|')[0]
        strand = fields[0].split('|')[1]
        if readCountsDict.has_key(region):
            pass
        else:
            readCountsDict[region] = {}
            readCountsDict[region]['minusC'] = []
            readCountsDict[region]['minusS'] = []
            readCountsDict[region]['plusC'] = []
            readCountsDict[region]['plusS'] = []
            readCountsDict[region]['minusCN'] = []
            readCountsDict[region]['minusSN'] = []
            readCountsDict[region]['plusCN'] = []
            readCountsDict[region]['plusSN'] = []
        if strand == '+':
            for ID in STARRFieldIDs:
                readCountsDict[region]['plusS'].append(float(fields[ID]))
            for ID in ControlFieldIDs:
                readCountsDict[region]['plusC'].append(float(fields[ID]))
        if strand == '-':
            for ID in STARRFieldIDs:
                readCountsDict[region]['minusS'].append(float(fields[ID]))
            for ID in ControlFieldIDs:
                readCountsDict[region]['minusC'].append(float(fields[ID]))

    print len(readCountsDict), 'region found'

    TRClist = []
    for ID in TotalReadCountDict.keys():
        TRClist.append(TotalReadCountDict[ID])

    MinimumTotalReadCounts = min(TRClist)

    LDSratioDict = {}
    for ID in TotalReadCountDict.keys():
        LDSratioDict[ID] = TotalReadCountDict[ID]/(MinimumTotalReadCounts + 0.0)
#        print ID, 'LDSratioDict[ID]', LDSratioDict[ID]
#    sys.exit(1)

    if doLDS:
        for region in readCountsDict.keys():
            i = 0
            for ID in STARRFieldIDs:
                counts = readCountsDict[region]['plusS'][i]
                normCounts = counts/LDSratioDict[ID]
                readCountsDict[region]['plusSN'].append(int(round(normCounts)))
                counts = readCountsDict[region]['minusS'][i]
                normCounts = counts/LDSratioDict[ID]
                readCountsDict[region]['minusSN'].append(int(round(normCounts)))
                i += 1
            i = 0
            for ID in ControlFieldIDs:
                counts = readCountsDict[region]['plusC'][i]
                normCounts = counts/LDSratioDict[ID]
                readCountsDict[region]['plusCN'].append(int(round(normCounts)))
                counts = readCountsDict[region]['minusC'][i]
                normCounts = counts/LDSratioDict[ID]
                readCountsDict[region]['minusCN'].append(int(round(normCounts)))
                i += 1
    else:
        MedianRatiosDict = {}
        for ID in TotalReadCountDict.keys():
            MedianRatiosDict[ID] = []
        for region in readCountsDict.keys():
            i = 0
            n = 0
            Prod = 1
            for ID in STARRFieldIDs:
                counts = readCountsDict[region]['plusS'][i]
                Prod = Prod*counts
                i += 1
                n += 1
            i = 0
            for ID in ControlFieldIDs:
                counts = readCountsDict[region]['plusC'][i]
                Prod = Prod*counts
                i += 1
                n += 1
            if Prod == 0:
                continue
            else:
                ProdNthroot = math.pow(Prod,(1./n))
            i = 0
            for ID in STARRFieldIDs:
                counts = readCountsDict[region]['plusS'][i]
                i += 1
                MedianRatiosDict[ID].append(counts/ProdNthroot)
            i = 0
            for ID in ControlFieldIDs:
                counts = readCountsDict[region]['plusC'][i]
                MedianRatiosDict[ID].append(counts/ProdNthroot)
                i += 1
            i = 0
            n = 0
            Prod = 1
            for ID in STARRFieldIDs:
                counts = readCountsDict[region]['minusS'][i]
                Prod = Prod*counts
                i += 1
                n += 1
            i = 0
            for ID in ControlFieldIDs:
                counts = readCountsDict[region]['minusC'][i]
                Prod = Prod*counts
                i += 1
                n += 1
            if Prod == 0:
                continue
            else:
                ProdNthroot = math.pow(Prod,(1./n))
            i = 0
            for ID in STARRFieldIDs:
                counts = readCountsDict[region]['minusS'][i]
                i += 1
                MedianRatiosDict[ID].append(counts/ProdNthroot)
            i = 0
            for ID in ControlFieldIDs:
                counts = readCountsDict[region]['minusC'][i]
                MedianRatiosDict[ID].append(counts/ProdNthroot)
                i += 1
        MMratioDict = {}
        for ID in TotalReadCountDict.keys():
            MMratioDict[ID] = np.median(np.array(MedianRatiosDict[ID]))
#            print ID, 'MMratioDict[ID]', MMratioDict[ID]
#	        sys.exit(1)
        for region in readCountsDict.keys():
            i = 0
            for ID in STARRFieldIDs:
                counts = readCountsDict[region]['plusS'][i]
                normCounts = counts/MMratioDict[ID]
                readCountsDict[region]['plusSN'].append(int(round(normCounts)))
                counts = readCountsDict[region]['minusS'][i]
                normCounts = counts/MMratioDict[ID]
                readCountsDict[region]['minusSN'].append(int(round(normCounts)))
                i += 1
            i = 0
            for ID in ControlFieldIDs:
                counts = readCountsDict[region]['plusC'][i]
                normCounts = counts/MMratioDict[ID]
                readCountsDict[region]['plusCN'].append(int(round(normCounts)))
                counts = readCountsDict[region]['minusC'][i]
                normCounts = counts/MMratioDict[ID]
                readCountsDict[region]['minusCN'].append(int(round(normCounts)))
                i += 1
            
    print 'finished read count normalization'

    print 'estimating dispersion'

    NRCArray = []
    regionlist = readCountsDict.keys()
    regionlist.sort()
    k = len(regionlist)/NP

    j=0
    for i in range(NP):
        NRCDict = {}
        DispersionDict = {}
        if i+1 == NP:
            while j < len(regionlist):
                region = regionlist[j]
                NRCDict[region] = readCountsDict[region]
                j += 1
        else:
            while j < k*(i+1):
                region = regionlist[j]
                NRCDict[region] = readCountsDict[region]
                j += 1
        NRCArray.append((NRCDict,DispersionDict))

    p = Pool(NP)
    DispersionDicts = p.map(estimate_dispersion, NRCArray)

    print 'finihed dispersion estimation'

    DispersionDict = {}

    outfile = open(outprefix + '.first_pass_params','w')

    outline = '#region'
    for ID in ControlFieldIDs: 
        outline = outline + '\t' + str(ID)
    for ID in STARRFieldIDs:
        outline = outline + '\t' + str(ID)
    outline =  outline + '\tmuControl\talphaControl\tmuSTARR\talphaSTARR'
    outfile.write(outline + '\n')

    for DD in DispersionDicts:
        for region in DD.keys():
            outline = region + '|+'
            i = 0
            for ID in STARRFieldIDs:
                outline = outline + '\t' + str(readCountsDict[region]['plusSN'][i])
                i+=1
            i = 0
            for ID in ControlFieldIDs: 
                outline = outline + '\t' + str(readCountsDict[region]['plusCN'][i])
                i+=1
            outline = outline + '\t' + str(DD[region]['muCplus'])
            outline = outline + '\t' + str(DD[region]['alphaCplus'])
            outline = outline + '\t' + str(DD[region]['muSplus'])
            outline = outline + '\t' + str(DD[region]['alphaSplus'])
            outfile.write(outline + '\n')

            outline = region + '|-'
            i = 0
            for ID in STARRFieldIDs:
                outline = outline + '\t' + str(readCountsDict[region]['minusSN'][i])
                i+=1
            i = 0
            for ID in ControlFieldIDs: 
                outline = outline + '\t' + str(readCountsDict[region]['minusCN'][i])
                i+=1
            outline = outline + '\t' + str(DD[region]['muCminus'])
            outline = outline + '\t' + str(DD[region]['alphaCminus'])
            outline = outline + '\t' + str(DD[region]['muSminus'])
            outline = outline + '\t' + str(DD[region]['alphaSminus'])
            outfile.write(outline + '\n')
            readCountsDict[region]['muCminus'] = DD[region]['muCminus']
            readCountsDict[region]['muSminus'] = DD[region]['muSminus']
            readCountsDict[region]['alphaCminus'] = DD[region]['alphaCminus']
            readCountsDict[region]['alphaSminus'] = DD[region]['alphaSminus']
            readCountsDict[region]['alphaCminus_bse'] = DD[region]['alphaCminus_bse']
            readCountsDict[region]['alphaSminus_bse'] = DD[region]['alphaSminus_bse']
            readCountsDict[region]['muCplus'] = DD[region]['muCplus']
            readCountsDict[region]['muSplus'] = DD[region]['muSplus']
            readCountsDict[region]['alphaCplus'] = DD[region]['alphaCplus']
            readCountsDict[region]['alphaSplus'] = DD[region]['alphaSplus']
            readCountsDict[region]['alphaCplus_bse'] = DD[region]['alphaCplus_bse']
            readCountsDict[region]['alphaSplus_bse'] = DD[region]['alphaSplus_bse']


    outfile.close()

    print 'moderating dispersion estimates'

    muEstimatesSTARRList = []
    muEstimatesControlList = []
    alphaEstimatesControlList = []
    alphaEstimatesSTARRList = []

    print len(readCountsDict.keys())

    for region in readCountsDict.keys():
        plusSNmean = np.mean(readCountsDict[region]['plusSN'])
        plusCNmean = np.mean(readCountsDict[region]['plusCN'])
        minusSNmean = np.mean(readCountsDict[region]['minusSN'])
        minusCNmean = np.mean(readCountsDict[region]['minusCN'])
        if plusSNmean == 0:
            pass
        elif math.fabs((plusSNmean - readCountsDict[region]['muSplus'])/plusSNmean) > 0.05:
            pass
        else:
            if readCountsDict[region]['alphaSplus_bse'] != 'nan':
                muEstimatesSTARRList.append(readCountsDict[region]['muSplus'])
                alphaEstimatesSTARRList.append(readCountsDict[region]['alphaSplus'])
        if plusCNmean == 0:
            pass
        elif math.fabs((plusCNmean - readCountsDict[region]['muCplus'])/plusCNmean) > 0.05:
            pass
        else:
            if readCountsDict[region]['alphaCplus_bse'] != 'nan':
                muEstimatesControlList.append(readCountsDict[region]['muCplus'])
                alphaEstimatesControlList.append(readCountsDict[region]['alphaCplus'])
        if minusSNmean == 0:
            pass
        elif math.fabs((minusSNmean - readCountsDict[region]['muSminus'])/minusSNmean) > 0.05:
            pass
        else:
            if readCountsDict[region]['alphaSminus_bse'] != 'nan':
                muEstimatesSTARRList.append(readCountsDict[region]['muSminus'])
                alphaEstimatesSTARRList.append(readCountsDict[region]['alphaSminus'])
        if minusCNmean == 0:
            pass
        elif math.fabs((minusCNmean - readCountsDict[region]['muCminus'])/minusCNmean) > 0.05:
            pass
        else:
            if readCountsDict[region]['alphaCminus_bse'] != 'nan':
                muEstimatesControlList.append(readCountsDict[region]['muCminus'])
                alphaEstimatesControlList.append(readCountsDict[region]['alphaCminus'])

#    minmuLog2ScaleSTARR = np.log2(min(muEstimatesSTARRList))
#    maxmuLog2ScaleSTARR = np.log2(max(muEstimatesSTARRList))
#    incrementSTARR = (maxmuLog2ScaleSTARR - minmuLog2ScaleSTARR)/dispBins

#    print 'min, max, increment, STARR', minmuLog2ScaleSTARR, maxmuLog2ScaleSTARR, incrementSTARR, len(muEstimatesSTARRList), len(muEstimatesSTARRList)

    DecilesDictSTARR = {}
    for i in range(1,dispBins+1):
        DecilesDictSTARR[i] = {}
        DecilesDictSTARR[i]['list'] = []

    muEstimatesSTARRList2 = []
    for mu in muEstimatesSTARRList:
        if str(mu) != 'nan':
            muEstimatesSTARRList2.append(round(mu,1))

    muEstimatesSTARRSet = list(Set(muEstimatesSTARRList2))
    muEstimatesSTARRSet.sort()

    print len(muEstimatesSTARRList2), len(muEstimatesSTARRSet), muEstimatesSTARRSet

    incrS = len(muEstimatesSTARRSet)/(dispBins + 0.0)
    for i in range(len(muEstimatesSTARRList)):
        mu = muEstimatesSTARRList[i]
        alpha = alphaEstimatesSTARRList[i]
        for j in range(1,dispBins+1):
            if mu >= muEstimatesSTARRSet[int(round((j-1)*incrS))] and mu < muEstimatesSTARRSet[min(int(round((j)*incrS)),len(muEstimatesSTARRSet)-1)]:
                DecilesDictSTARR[j]['list'].append(alpha)
#                print i, j, mu, int(round((j-1)*incr)), min(int(round((j)*incr)),len(muEstimatesSTARRSet)-1), muEstimatesSTARRSet[int(round((j-1)*incr))], muEstimatesSTARRSet[min(int(round((j)*incr)),len(muEstimatesSTARRSet)-1)]
                break

    print 'STARR dispersion parameters:'

    for i in range(1,dispBins+1):
        print i, muEstimatesSTARRSet[int(round((i-1)*incrS))], len(DecilesDictSTARR[i]['list']), np.mean(DecilesDictSTARR[i]['list']), np.std(DecilesDictSTARR[i]['list']), np.median(DecilesDictSTARR[i]['list']), min(DecilesDictSTARR[i]['list']), max(DecilesDictSTARR[i]['list'])

#    minmuLog2ScaleControl = np.log2(min(muEstimatesControlList))
#    maxmuLog2ScaleControl = np.log2(max(muEstimatesControlList))
#    incrementControl = (maxmuLog2ScaleControl - minmuLog2ScaleControl)/dispBins

#    print 'min, max, increment, Control', minmuLog2ScaleControl, maxmuLog2ScaleControl, incrementControl, len(muEstimatesControlList), len(muEstimatesControlList)

    DecilesDictControl = {}
    for i in range(1,dispBins+1):
        DecilesDictControl[i] = {}
        DecilesDictControl[i]['list'] = []

    muEstimatesControlList2 = []
    for mu in muEstimatesControlList:
        if str(mu) != 'nan':
            muEstimatesControlList2.append(round(mu,1))

    muEstimatesControlSet = list(Set(muEstimatesControlList))
    muEstimatesControlSet.sort()
    incrC = len(muEstimatesControlSet)/(dispBins + 0.0)

    for i in range(len(muEstimatesControlList)):
        mu = muEstimatesControlList[i]
        alpha = alphaEstimatesControlList[i]
        for j in range(1,dispBins+1):
            if mu >= muEstimatesControlSet[int(round((j-1)*incrC))] and mu < muEstimatesControlSet[min(int(round((j)*incrC)),len(muEstimatesControlSet)-1)]:
                DecilesDictControl[j]['list'].append(alpha)
                break

    print 'Control disperson parameters:'

#    for i in range(1,dispBins+1):
#        print i, minmuLog2ScaleControl, minmuLog2ScaleControl + (i-1)*incrementControl, maxmuLog2ScaleControl, len(DecilesDictControl[i]['list']), np.mean(DecilesDictControl[i]['list']), np.median(DecilesDictControl[i]['list']), min(DecilesDictControl[i]['list']), max(DecilesDictControl[i]['list'])
    for i in range(1,dispBins+1):
        print i, muEstimatesControlSet[int(round((i-1)*incrC))], len(DecilesDictControl[i]['list']), np.mean(DecilesDictControl[i]['list']), np.std(DecilesDictControl[i]['list']), np.median(DecilesDictControl[i]['list']), min(DecilesDictControl[i]['list']), max(DecilesDictControl[i]['list'])

    for region in readCountsDict.keys():
        plusSNmean = np.mean(readCountsDict[region]['plusSN'])
        plusCNmean = np.mean(readCountsDict[region]['plusCN'])
        minusSNmean = np.mean(readCountsDict[region]['minusSN'])
        minusCNmean = np.mean(readCountsDict[region]['minusCN'])

        plusSandCMean = np.mean([plusSNmean,plusCNmean])
        minusSandCMean = np.mean([minusSNmean,minusCNmean])

        if plusSandCMean <= min(muEstimatesSTARRSet):
            decilealphaSTARR = np.mean(DecilesDictSTARR[1]['list'])
            decilealphaSTARR_std = np.std(DecilesDictSTARR[1]['list'])
        elif plusSandCMean <= max(muEstimatesSTARRSet):
            decilealphaSTARR = np.mean(DecilesDictSTARR[dispBins]['list'])
            decilealphaSTARR_std = np.std(DecilesDictSTARR[dispBins]['list'])
        else:
            for j in range(1,dispBins+1):
                if plusSandCMean >= muEstimatesSTARRSet[int(round((j-1)*incrS))] and plusSandCMean < muEstimatesSTARRSet[min(int(round((j)*incrS)),len(muEstimatesSTARRSet)-1)]:
                    decilealphaSTARR = np.mean(DecilesDictSTARR[j]['list'])
                    decilealphaSTARR_std = np.std(DecilesDictSTARR[j]['list'])
                    break
        if plusSandCMean <= min(muEstimatesControlSet):
            decilealphaControl = np.mean(DecilesDictControl[1]['list'])
            decilealphaControl_std = np.std(DecilesDictControl[1]['list'])
        elif plusSandCMean >= max(muEstimatesControlSet):
            decilealphaControl = np.mean(DecilesDictControl[dispBins]['list'])
            decilealphaControl_std = np.std(DecilesDictControl[dispBins]['list'])
        else:
            for j in range(1,dispBins+1):
                if plusSandCMean >= muEstimatesControlSet[int(round((j-1)*incrC))] and plusSandCMean < muEstimatesControlSet[min(int(round((j)*incrC)),len(muEstimatesControlSet)-1)]:
                    decilealphaControl = np.mean(DecilesDictControl[j]['list'])
                    decilealphaControl_std = np.mean(DecilesDictControl[j]['list'])
                    break

        readCountsDict[region]['alphaSandCplus_decile_STARR'] = BayesianUpdateMu(readCountsDict[region]['alphaSplus'],decilealphaSTARR,decilealphaSTARR_std)
        readCountsDict[region]['alphaSandCplus_decile_Control'] = BayesianUpdateMu(readCountsDict[region]['alphaCplus'],decilealphaControl,decilealphaControl_std)
        readCountsDict[region]['plusSandCMean'] = plusSandCMean

        logmean = np.log2(minusSandCMean)
        if minusSandCMean <= min(muEstimatesSTARRSet):
            decilealphaSTARR = np.mean(DecilesDictSTARR[1]['list'])
            decilealphaSTARR_std = np.std(DecilesDictSTARR[1]['list'])
        elif minusSandCMean >= max(muEstimatesSTARRSet):
            decilealphaSTARR = np.mean(DecilesDictSTARR[dispBins]['list'])
            decilealphaSTARR_std = np.std(DecilesDictSTARR[dispBins]['list'])
        else:
            for j in range(1,dispBins+1):
                if minusSandCMean >= muEstimatesSTARRSet[int(round((j-1)*incrS))] and minusSandCMean < muEstimatesSTARRSet[min(int(round((j)*incrS)),len(muEstimatesSTARRSet)-1)]:
                    decilealphaSTARR = np.mean(DecilesDictSTARR[j]['list'])
                    decilealphaSTARR_std = np.std(DecilesDictSTARR[j]['list'])
                    break
        logmean = np.log2(minusSandCMean)
        if minusSandCMean <= min(muEstimatesControlSet):
            decilealphaControl = np.mean(DecilesDictControl[1]['list'])
            decilealphaControl_std = np.std(DecilesDictControl[1]['list'])
        elif minusSandCMean >= max(muEstimatesControlSet):
            decilealphaControl = np.mean(DecilesDictControl[dispBins]['list'])
            decilealphaControl_std = np.std(DecilesDictControl[dispBins]['list'])
        else:
            for j in range(1,dispBins+1):
                if minusSandCMean >= muEstimatesControlSet[int(round((j-1)*incrC))] and minusSandCMean < muEstimatesControlSet[min(int(round((j)*incrC)),len(muEstimatesControlSet)-1)]:
                    decilealphaControl = np.mean(DecilesDictControl[j]['list'])
                    decilealphaControl_std = np.std(DecilesDictControl[j]['list'])
                    break

        readCountsDict[region]['alphaSandCminus_decile_STARR'] = BayesianUpdateMu(readCountsDict[region]['alphaSminus'],decilealphaSTARR,decilealphaSTARR_std)
        readCountsDict[region]['alphaSandCminus_decile_Control'] = BayesianUpdateMu(readCountsDict[region]['alphaCminus'],decilealphaControl,decilealphaControl_std)
        readCountsDict[region]['minusSandCMean'] = minusSandCMean

        if plusSNmean <= min(muEstimatesSTARRSet):
            decilealpha = np.mean(DecilesDictSTARR[1]['list'])
            decilealpha_std = np.std(DecilesDictSTARR[1]['list'])
        elif plusSNmean >= max(muEstimatesSTARRSet):
            decilealpha = np.mean(DecilesDictSTARR[dispBins]['list'])
            decilealpha_std = np.std(DecilesDictSTARR[dispBins]['list'])
        else:
            for j in range(1,dispBins+1):
                if plusSNmean > muEstimatesSTARRSet[int(round((j-1)*incrS))] and plusSNmean <= muEstimatesSTARRSet[min(int(round((j)*incrS)),len(muEstimatesSTARRSet)-1)]:
                    decilealpha = np.mean(DecilesDictSTARR[j]['list'])
                    decilealpha_std = np.std(DecilesDictSTARR[j]['list'])
                    break
        readCountsDict[region]['alphaSplus_decile'] = BayesianUpdateMu(readCountsDict[region]['alphaSplus'],decilealpha,decilealpha_std)

        if plusCNmean <= min(muEstimatesControlSet):
            decilealpha = np.mean(DecilesDictControl[1]['list'])
            decilealpha_std = np.std(DecilesDictControl[1]['list'])
        elif plusCNmean >= max(muEstimatesControlSet):
            decilealpha = np.mean(DecilesDictControl[dispBins]['list'])
            decilealpha_std = np.std(DecilesDictControl[dispBins]['list'])
        else:
            for j in range(1,dispBins+1):
                if plusCNmean > muEstimatesControlSet[int(round((j-1)*incrC))] and plusCNmean <= muEstimatesControlSet[min(int(round((j)*incrC)),len(muEstimatesControlSet)-1)]:
                    decilealpha = np.mean(DecilesDictControl[j]['list'])
                    decilealpha_std = np.std(DecilesDictControl[j]['list'])
                    break
        readCountsDict[region]['alphaCplus_decile'] = BayesianUpdateMu(readCountsDict[region]['alphaCplus'],decilealpha,decilealpha_std)

        if minusSNmean <= min(muEstimatesSTARRSet):
            decilealpha = np.mean(DecilesDictSTARR[1]['list'])
            decilealpha_std = np.std(DecilesDictSTARR[1]['list'])
        elif minusSNmean >= max(muEstimatesSTARRSet):
            decilealpha = np.mean(DecilesDictSTARR[dispBins]['list'])
            decilealpha_std = np.std(DecilesDictSTARR[dispBins]['list'])
        else:
            for j in range(1,dispBins+1):
                if minusSNmean > muEstimatesSTARRSet[int(round((j-1)*incrS))] and minusSNmean <= muEstimatesSTARRSet[min(int(round((j)*incrS)),len(muEstimatesSTARRSet)-1)]:
                    decilealpha = np.mean(DecilesDictSTARR[j]['list'])
                    decilealpha_std = np.std(DecilesDictSTARR[j]['list'])
                    break
        readCountsDict[region]['alphaSminus_decile'] = BayesianUpdateMu(readCountsDict[region]['alphaSminus'],decilealpha,decilealpha_std)

        if minusCNmean <= min(muEstimatesControlSet):
            decilealpha = np.mean(DecilesDictControl[1]['list'])
            decilealpha_std = np.std(DecilesDictControl[1]['list'])
        elif minusCNmean >= max(muEstimatesControlSet):
            decilealpha = np.mean(DecilesDictControl[dispBins]['list'])
            decilealpha_std = np.std(DecilesDictControl[dispBins]['list'])
        else:
            for j in range(1,dispBins+1):
                if minusCNmean > muEstimatesControlSet[int(round((j-1)*incrC))] and minusCNmean <= muEstimatesControlSet[min(int(round((j)*incrC)),len(muEstimatesControlSet)-1)]:
                    decilealpha = np.mean(DecilesDictControl[j]['list'])
                    decilealpha_std = np.std(DecilesDictControl[j]['list'])
                    break
        readCountsDict[region]['alphaCminus_decile'] = BayesianUpdateMu(readCountsDict[region]['alphaCminus'],decilealpha,decilealpha_std)

    print 'sampling NB distribution and estimating significant changes'

    NRCArray2 = []
    regionlist = readCountsDict.keys()
    regionlist.sort()
    k = len(regionlist)/NP

    j=0
    for i in range(NP):
        NRCDict = {}
        PvalueDict = {}
        if i+1 == NP:
            while j < len(regionlist):
                region = regionlist[j]
                NRCDict[region] = readCountsDict[region]
                j += 1
        else:
            while j < k*(i+1):
                region = regionlist[j]
                NRCDict[region] = readCountsDict[region]
                j += 1
        NRCArray2.append((NRCDict,PvalueDict,MFC,NSamplings))

    p = Pool(NP)
    PvalueDictDicts = p.map(sampling_p_values, NRCArray2)

    outfile = open(outprefix + '.final_results','w')

    outline = '#region'
    for ID in STARRFieldIDs:
        outline = outline + '\t' + str(ID) + '|S|+'
    for ID in ControlFieldIDs:
        outline = outline + '\t' + str(ID) + '|C|+'
    for ID in STARRFieldIDs:
        outline = outline + '\t' + str(ID) + '|S|-'
    for ID in ControlFieldIDs:
        outline = outline + '\t' + str(ID) + '|C|-'
    outline =  outline + '\tplus_logFC\tp_val\tminus_logFC\tp_val\tabs_plusminus_logFC\tp_val'
    outfile.write(outline + '\n')

    for PP in PvalueDictDicts:
        for region in PP.keys():
            (p_plus,p_minus,pm,PlusMinusRatio,PlusFC,MinusFC) = PP[region]
            outline = region
            i = 0
            for ID in STARRFieldIDs:
                outline = outline + '\t' + str(readCountsDict[region]['plusSN'][i])
                i+=1
            i = 0
            for ID in ControlFieldIDs: 
                outline = outline + '\t' + str(readCountsDict[region]['plusCN'][i])
                i+=1
            i = 0
            for ID in STARRFieldIDs:
                outline = outline + '\t' + str(readCountsDict[region]['minusSN'][i])
                i+=1
            i = 0
            for ID in ControlFieldIDs: 
                outline = outline + '\t' + str(readCountsDict[region]['minusCN'][i])
                i+=1
            outline = outline + '\t' + str(np.log2(PlusFC))
            outline = outline + '\t' + str(p_plus)
            outline = outline + '\t' + str(np.log2(MinusFC))
            outline = outline + '\t' + str(p_minus)
            if PlusMinusRatio >= 1:
                outline = outline + '\t' + str(np.log2(PlusMinusRatio))
            if PlusMinusRatio < 1:
                outline = outline + '\t' + str(np.log2(1/PlusMinusRatio))
            outline = outline + '\t' + str(pm)
            outfile.write(outline + '\n')

    outfile.close()

# Script entry point.  The call is guarded so the pipeline only starts when
# this file is executed directly.  run() creates a multiprocessing Pool, and
# on platforms where worker processes are started by re-importing the main
# module an unguarded call would re-launch the whole analysis in every child.
if __name__ == '__main__':
    run()

