##################################
#                                #
# Last modified 2017/06/05       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import gc
import math
import string
from sets import Set
import os
from multiprocessing import Pool
from threading import Thread
import numpy as np
from patsy import dmatrices, dmatrix, demo_data
from scipy.stats import nbinom
# from statsmodels import robust
from statsmodels.discrete.discrete_model import NegativeBinomial
import statsmodels.api as sm

def BayesianUpdateMu(alpha,alpha_mean,alpha_std):
    """Return the precision-weighted combination of alpha and alpha_mean.

    Both values are weighted by the same precision, 1/alpha_std**2.

    NOTE(review): because the two weights are identical they cancel, so for
    a finite, non-zero alpha_std this is the arithmetic mean of alpha and
    alpha_mean.  Confirm whether a separate standard error for alpha was
    meant to be used for the first term.
    """
    variance = alpha_std**2
    weighted_sum = alpha/variance + alpha_mean/variance
    total_precision = 1/variance + 1/variance

    return weighted_sum/total_precision

def _ll_nb2(y, X, beta, alph):
    """Return the negative-binomial log-pmf of y under a log-linear mean.

    The mean is exp(X . beta).  The dispersion alph is converted to
    scipy's (n, p) parameterisation as n = 1/alph and p = n/(n + mean).
    The value returned is whatever nbinom.logpmf produces for y.
    """
    linear_predictor = np.dot(X, beta)
    mean = np.exp(linear_predictor)
    n_param = 1/alph
    success_prob = n_param/(n_param+mean)
    return nbinom.logpmf(y, n_param, success_prob)

def _pseudocount_if_all_zero(counts):
    """Return a copy of counts; if every count is 0 the first one becomes 1."""
    values = list(counts)
    # max() is evaluated once here rather than once per element.
    if values and max(values) == 0:
        values[0] = 1
    return values


def _fit_intercept_only_nb(counts):
    """Fit NegativeBinomial to counts with an all-ones regressor.

    Returns (mu, alpha, alpha_bse) where mu is exp(params[0]), alpha is
    params[1] and alpha_bse is bse[1], or the string 'nan' when the
    standard error cannot be obtained.
    """
    fit = NegativeBinomial(counts, [1] * len(counts)).fit(disp=0)
    mu = np.exp(fit.params[0])
    alpha = fit.params[1]
    try:
        alpha_bse = fit.bse[1]
    except Exception:
        # Keep the 'nan' string sentinel that downstream code compares
        # against, but no longer swallow KeyboardInterrupt/SystemExit.
        alpha_bse = 'nan'
    return mu, alpha, alpha_bse


def estimate_dispersion(args):
    """Estimate per-region negative-binomial mean and dispersion.

    Parameters
    ----------
    args : tuple
        (NRCDict, DispersionDict).  NRCDict maps each region to a dict
        holding the count lists 'minusCN', 'minusSN', 'plusCN' and
        'plusSN'.  DispersionDict is filled in place.  A single tuple is
        taken so the function can be used with Pool.map.

    Returns
    -------
    dict
        DispersionDict, with one entry per region containing the keys
        mu{C,S}{minus,plus}, alpha{C,S}{minus,plus} and
        alpha{C,S}{minus,plus}_bse.

    When the 'minusCN' list has exactly one element no fit is run for
    either control list: the control mu values are the (pseudo-counted)
    first counts and the control alpha fields are set to the string
    'to_be_assigned'.
    """
    NRCDict, DispersionDict = args

    for processed, region in enumerate(NRCDict, 1):
        if processed % 1000 == 0:
            print('%d regions processed' % processed)

        counts = NRCDict[region]
        minusCC = _pseudocount_if_all_zero(counts['minusCN'])
        minusSS = _pseudocount_if_all_zero(counts['minusSN'])
        plusCC = _pseudocount_if_all_zero(counts['plusCN'])
        plusSS = _pseudocount_if_all_zero(counts['plusSN'])

        muSSminus, alphaSSminus, alphaSSminus_bse = _fit_intercept_only_nb(minusSS)
        muSSplus, alphaSSplus, alphaSSplus_bse = _fit_intercept_only_nb(plusSS)

        if len(minusCC) == 1:
            # Single control value: nothing to fit, dispersion is filled
            # in later by the caller.
            muCCminus = minusCC[0]
            muCCplus = plusCC[0]
            alphaCCminus = 'to_be_assigned'
            alphaCCplus = 'to_be_assigned'
            alphaCCminus_bse = 'to_be_assigned'
            alphaCCplus_bse = 'to_be_assigned'
        else:
            muCCminus, alphaCCminus, alphaCCminus_bse = _fit_intercept_only_nb(minusCC)
            muCCplus, alphaCCplus, alphaCCplus_bse = _fit_intercept_only_nb(plusCC)

        DispersionDict[region] = {'muCminus':muCCminus, 'alphaCminus':alphaCCminus, 'alphaCminus_bse':alphaCCminus_bse,
                                  'muCplus':muCCplus, 'alphaCplus':alphaCCplus, 'alphaCplus_bse':alphaCCplus_bse,
                                  'muSminus':muSSminus, 'alphaSminus':alphaSSminus, 'alphaSminus_bse':alphaSSminus_bse,
                                  'muSplus':muSSplus, 'alphaSplus':alphaSSplus, 'alphaSplus_bse':alphaSSplus_bse}

    return DispersionDict

def _draw_nb_samples(mean, alpha, size):
    """Draw `size` negative-binomial counts for the given mean and dispersion.

    numpy's (n, p) parameterisation is used with n = 1/alpha and
    p = 1 - mean*alpha/(1 + mean*alpha).
    """
    # 1. (float) so an integer alpha is not integer-divided under Python 2.
    r = 1. / alpha
    p = mean * alpha / (1 + mean * alpha)
    return np.random.negative_binomial(r, 1 - p, size=size)


def sampling_p_values(args):
    """Compute sampling-based p-values for every region.

    Parameters
    ----------
    args : tuple
        (NRCDict, PvalueDict, MFC, NSamplings).  NRCDict maps each region
        to a dict with the count lists 'plusSN', 'plusCN', 'minusSN',
        'minusCN', the eight '*_decile*' dispersion values and the
        'plusSandCMean' / 'minusSandCMean' values.  PvalueDict is filled
        in place.  MFC is the fold-change threshold and NSamplings the
        number of random draws per distribution.  A single tuple is taken
        so the function can be used with Pool.map.

    Returns
    -------
    dict
        PvalueDict, mapping region to
        (p_plus, p_minus, pm, PlusMinusRatio, PlusFC, MinusFC).
        The first three entries are the string '<MFC' when both strand
        fold changes are below MFC.  Otherwise p_plus / p_minus are the
        fractions of draws whose (STARR + 1)/(control + 1) ratio exceeds
        the observed fold change, and pm is the fraction of draws whose
        plus/minus ratio is more extreme than the observed one (1 when
        the two observed fold changes are equal).
    """
    NRCDict, PvalueDict, MFC, NSamplings = args

    for processed, region in enumerate(NRCDict, 1):
        if processed % 100 == 0:
            print('%d regions processed' % processed)

        entry = NRCDict[region]
        # Means are floored at 1.
        plusSNmean = max(np.mean(entry['plusSN']),1.)
        plusCNmean = max(np.mean(entry['plusCN']),1.)
        minusSNmean = max(np.mean(entry['minusSN']),1.)
        minusCNmean = max(np.mean(entry['minusCN']),1.)
        alphaSplus_decile = entry['alphaSplus_decile']
        alphaCplus_decile = entry['alphaCplus_decile']
        alphaSminus_decile = entry['alphaSminus_decile']
        alphaCminus_decile = entry['alphaCminus_decile']
        alphaSandCplus_decile_STARR = entry['alphaSandCplus_decile_STARR']
        alphaSandCplus_decile_Control = entry['alphaSandCplus_decile_Control']
        plusSandCMean = max(entry['plusSandCMean'],1.)
        alphaSandCminus_decile_STARR = entry['alphaSandCminus_decile_STARR']
        alphaSandCminus_decile_Control = entry['alphaSandCminus_decile_Control']
        minusSandCMean = max(entry['minusSandCMean'],1.)

        # STARR and control draws share one mean per strand.  They are
        # drawn before the MFC filter, in this order, so the sequence of
        # random numbers consumed is the same as before the rewrite.
        plusSNsamples = _draw_nb_samples(plusSandCMean, alphaSandCplus_decile_STARR, NSamplings)
        plusCNsamples = _draw_nb_samples(plusSandCMean, alphaSandCplus_decile_Control, NSamplings)
        minusSNsamples = _draw_nb_samples(minusSandCMean, alphaSandCminus_decile_STARR, NSamplings)
        minusCNsamples = _draw_nb_samples(minusSandCMean, alphaSandCminus_decile_Control, NSamplings)

        PlusFC = (plusSNmean + 1.)/(plusCNmean + 1.)
        MinusFC = (minusSNmean + 1.)/(minusCNmean + 1.)
        PlusMinusRatio = PlusFC/MinusFC
        meanPlusMinusFC = np.mean([PlusFC,MinusFC])

        if PlusFC < MFC and MinusFC < MFC:
            PvalueDict[region] = ('<MFC','<MFC','<MFC',PlusMinusRatio,PlusFC,MinusFC)
            continue

        # Draws for the plus/minus comparison: STARR means are the control
        # means scaled by the average of the two observed fold changes.
        PMplusSNsamples = _draw_nb_samples(plusCNmean*meanPlusMinusFC, alphaSplus_decile, NSamplings)
        PMplusCNsamples = _draw_nb_samples(plusCNmean, alphaCplus_decile, NSamplings)
        PMminusSNsamples = _draw_nb_samples(minusCNmean*meanPlusMinusFC, alphaSminus_decile, NSamplings)
        PMminusCNsamples = _draw_nb_samples(minusCNmean, alphaCminus_decile, NSamplings)

        plusRatios = (plusSNsamples + 1.)/(plusCNsamples + 1.)
        minusRatios = (minusSNsamples + 1.)/(minusCNsamples + 1.)
        p_plus = float(np.count_nonzero(plusRatios > PlusFC))/NSamplings
        p_minus = float(np.count_nonzero(minusRatios > MinusFC))/NSamplings

        if PlusFC == MinusFC:
            pm = 1
        else:
            pmRatios = ((PMplusSNsamples + 1.)/(PMplusCNsamples + 1.))/((PMminusSNsamples + 1.)/(PMminusCNsamples + 1.))
            if PlusMinusRatio > 1:
                pmratiopass = np.count_nonzero(pmRatios > PlusMinusRatio)
            elif PlusMinusRatio < 1:
                pmratiopass = np.count_nonzero(pmRatios < PlusMinusRatio)
            else:
                pmratiopass = 0
            pm = float(pmratiopass)/NSamplings

        PvalueDict[region] = (p_plus,p_minus,pm,PlusMinusRatio,PlusFC,MinusFC)

    return PvalueDict

def run():

    if len(sys.argv) < 7:
        print 'usage: python %s fragment_counts STARRfields Controlfields window minFoldChange sampling_size outprefix [-norm median-of-ratios] [-p threads] [-dispBins number]' % sys.argv[0]
        print '\tNote: the script assumes the following format of the fragment counts file:'
        print '\t\t\tchr10:102658058-102659058|+  for plus strand fragments'
        print '\t\t\tchr10:102658058-102659058|-  for minus strand fragments'
        print '\tthe fragment counts file can be zipped'
        print '\tNote: STARRfields Controlfields should be comma separated'
        sys.exit(1)

    readCounts = sys.argv[1]
    STARRFieldIDs = []
    fields = sys.argv[2].split(',')
    for ID in fields:
        STARRFieldIDs.append(int(ID))
    ControlFieldIDs = []
    fields = sys.argv[3].split(',')
    for ID in fields:
        ControlFieldIDs.append(int(ID))
    WS = int(sys.argv[4])
    MFC = float(sys.argv[5])
    NSamplings = int(sys.argv[6])
    outprefix = sys.argv[7]

    dispBins = 10
    if '-dispBins' in sys.argv:
        dispBins = int(sys.argv[sys.argv.index('-dispBins') + 1])

    doLDS = True
    if '-norm' in sys.argv:
        if sys.argv[sys.argv.index('-norm') + 1] == 'median-of-ratios':
            print 'will apply the median-of-ratios normalization'
            doLDS = False

    NP = 1
    if '-p' in sys.argv:
        NP = int(sys.argv[sys.argv.index('-p') + 1])
        print 'will run on', NP, 'threads'

    TotalReadCountDict = {}
    for ID in STARRFieldIDs:
        TotalReadCountDict[ID] = 0
    for ID in ControlFieldIDs:
        TotalReadCountDict[ID] = 0

    fragmentsDict = {}
    readCountsDict = {}

    if readCounts.endswith('.bz2'):
        cmd = 'bzip2 -cd ' + readCounts
    elif readCounts.endswith('.gz'):
        cmd = 'gunzip -c ' + readCounts
    elif readCounts.endswith('.zip'):
        cmd = 'unzip -p ' + readCounts
    else:
        cmd = 'cat ' + readCounts
    p = os.popen(cmd, "r")
    line = 'line'
    RP = 0
    while line != '':
        line = p.readline().strip()
        fields = line.split('\t')
        RP += 1
        if RP % 10000 == 0:
            print RP, 'fragments processed'
        if line == '':
            break
        if line.startswith('#'):
            continue
        fields = line.strip().split('\t')
        for ID in STARRFieldIDs:
            TotalReadCountDict[ID] += float(fields[ID])
        for ID in ControlFieldIDs:
            TotalReadCountDict[ID] += float(fields[ID])
        region = fields[0].split('|')[0]
        chr = region.split(':')[0]
        left = int(region.split(':')[1].split('-')[0])
        L = left - left % WS
        right = int(region.split(':')[1].split('-')[1])
        R = right - right % WS
        strand = fields[0].split('|')[1]
        region = chr + ':' + str(L) + '-' + str(R)
        if readCountsDict.has_key(region):
            pass
        else:
            readCountsDict[region] = {}
            readCountsDict[region]['minusC'] = []
            readCountsDict[region]['minusS'] = []
            readCountsDict[region]['plusC'] = []
            readCountsDict[region]['plusS'] = []
            readCountsDict[region]['minusCN'] = []
            readCountsDict[region]['minusSN'] = []
            readCountsDict[region]['plusCN'] = []
            readCountsDict[region]['plusSN'] = []
            for ID in STARRFieldIDs:
                readCountsDict[region]['minusS'].append(0)
                readCountsDict[region]['plusS'].append(0)
            for ID in ControlFieldIDs:
                readCountsDict[region]['minusC'].append(0)
                readCountsDict[region]['plusC'].append(0)
        if strand == '+':
            i=0
            for ID in STARRFieldIDs:
                readCountsDict[region]['plusS'][i] += (float(fields[ID]))
                i+=1
            i=0
            for ID in ControlFieldIDs:
                readCountsDict[region]['plusC'][i] += (float(fields[ID]))
                i+=1
        if strand == '-':
            i=0
            for ID in STARRFieldIDs:
                readCountsDict[region]['minusS'][i] += (float(fields[ID]))
                i+=1
            i=0
            for ID in ControlFieldIDs:
                readCountsDict[region]['minusC'][i] += (float(fields[ID]))
                i+=1

    print len(readCountsDict), 'regions found'

    TRClist = []
    for ID in TotalReadCountDict.keys():
        TRClist.append(TotalReadCountDict[ID])

    MinimumTotalReadCounts = min(TRClist)

    LDSratioDict = {}
    for ID in TotalReadCountDict.keys():
        LDSratioDict[ID] = TotalReadCountDict[ID]/(MinimumTotalReadCounts + 0.0)

    if doLDS:
        for region in readCountsDict.keys():
            i = 0
            for ID in STARRFieldIDs:
                counts = readCountsDict[region]['plusS'][i]
                normCounts = counts/LDSratioDict[ID]
                readCountsDict[region]['plusSN'].append(int(round(normCounts)))
                counts = readCountsDict[region]['minusS'][i]
                normCounts = counts/LDSratioDict[ID]
                readCountsDict[region]['minusSN'].append(int(round(normCounts)))
                i += 1
            i = 0
            for ID in ControlFieldIDs:
                counts = readCountsDict[region]['plusC'][i]
                normCounts = counts/LDSratioDict[ID]
                readCountsDict[region]['plusCN'].append(int(round(normCounts)))
                counts = readCountsDict[region]['minusC'][i]
                normCounts = counts/LDSratioDict[ID]
                readCountsDict[region]['minusCN'].append(int(round(normCounts)))
                i += 1
    else:
        MedianRatiosDict = {}
        for ID in TotalReadCountDict.keys():
            MedianRatiosDict[ID] = []
        for region in readCountsDict.keys():
            i = 0
            n = 0
            Prod = 1
            for ID in STARRFieldIDs:
                counts = readCountsDict[region]['plusS'][i]
                Prod = Prod*counts
                i += 1
                n += 1
            i = 0
            for ID in ControlFieldIDs:
                counts = readCountsDict[region]['plusC'][i]
                Prod = Prod*counts
                i += 1
                n += 1
            if Prod == 0:
                continue
            else:
                ProdNthroot = math.pow(Prod,(1./n))
            i = 0
            for ID in STARRFieldIDs:
                counts = readCountsDict[region]['plusS'][i]
                i += 1
                MedianRatiosDict[ID].append(counts/ProdNthroot)
            i = 0
            for ID in ControlFieldIDs:
                counts = readCountsDict[region]['plusC'][i]
                MedianRatiosDict[ID].append(counts/ProdNthroot)
                i += 1
            i = 0
            n = 0
            Prod = 1
            for ID in STARRFieldIDs:
                counts = readCountsDict[region]['minusS'][i]
                Prod = Prod*counts
                i += 1
                n += 1
            i = 0
            for ID in ControlFieldIDs:
                counts = readCountsDict[region]['minusC'][i]
                Prod = Prod*counts
                i += 1
                n += 1
            if Prod == 0:
                continue
            else:
                ProdNthroot = math.pow(Prod,(1./n))
            i = 0
            for ID in STARRFieldIDs:
                counts = readCountsDict[region]['minusS'][i]
                i += 1
                MedianRatiosDict[ID].append(counts/ProdNthroot)
            i = 0
            for ID in ControlFieldIDs:
                counts = readCountsDict[region]['minusC'][i]
                MedianRatiosDict[ID].append(counts/ProdNthroot)
                i += 1
        MMratioDict = {}
        for ID in TotalReadCountDict.keys():
            MMratioDict[ID] = np.median(np.array(MedianRatiosDict[ID]))
        for region in readCountsDict.keys():
            i = 0
            for ID in STARRFieldIDs:
                counts = readCountsDict[region]['plusS'][i]
                normCounts = counts/MMratioDict[ID]
                readCountsDict[region]['plusSN'].append(int(round(normCounts)))
                counts = readCountsDict[region]['minusS'][i]
                normCounts = counts/MMratioDict[ID]
                readCountsDict[region]['minusSN'].append(int(round(normCounts)))
                i += 1
            i = 0
            for ID in ControlFieldIDs:
                counts = readCountsDict[region]['plusC'][i]
                normCounts = counts/MMratioDict[ID]
                readCountsDict[region]['plusCN'].append(int(round(normCounts)))
                counts = readCountsDict[region]['minusC'][i]
                normCounts = counts/MMratioDict[ID]
                readCountsDict[region]['minusCN'].append(int(round(normCounts)))
                i += 1
            
    print 'finished read count normalization'

    print 'estimating dispersion'

    NRCArray = []
    regionlist = readCountsDict.keys()
    regionlist.sort()
    k = len(regionlist)/NP

    j=0
    for i in range(NP):
        NRCDict = {}
        DispersionDict = {}
        if i+1 == NP:
            while j < len(regionlist):
                region = regionlist[j]
                NRCDict[region] = readCountsDict[region]
                j += 1
        else:
            while j < k*(i+1):
                region = regionlist[j]
                NRCDict[region] = readCountsDict[region]
                j += 1
        NRCArray.append((NRCDict,DispersionDict))

    p = Pool(NP)
    DispersionDicts = p.map(estimate_dispersion, NRCArray)

    print 'finihed dispersion estimation'

    DispersionDict = {}

    for DD in DispersionDicts:
        for region in DD.keys():
            readCountsDict[region]['muCminus'] = DD[region]['muCminus']
            readCountsDict[region]['muSminus'] = DD[region]['muSminus']
            readCountsDict[region]['alphaCminus'] = DD[region]['alphaCminus']
            readCountsDict[region]['alphaSminus'] = DD[region]['alphaSminus']
            readCountsDict[region]['alphaCminus_bse'] = DD[region]['alphaCminus_bse']
            readCountsDict[region]['alphaSminus_bse'] = DD[region]['alphaSminus_bse']
            readCountsDict[region]['muCplus'] = DD[region]['muCplus']
            readCountsDict[region]['muSplus'] = DD[region]['muSplus']
            readCountsDict[region]['alphaCplus'] = DD[region]['alphaCplus']
            readCountsDict[region]['alphaSplus'] = DD[region]['alphaSplus']
            readCountsDict[region]['alphaCplus_bse'] = DD[region]['alphaCplus_bse']
            readCountsDict[region]['alphaSplus_bse'] = DD[region]['alphaSplus_bse']

    print 'moderating dispersion estimates'

    muEstimatesSTARRList = []
    muEstimatesControlList = []
    alphaEstimatesControlList = []
    alphaEstimatesSTARRList = []

    print len(readCountsDict.keys())

    for region in readCountsDict.keys():
        plusSNmean = np.mean(readCountsDict[region]['plusSN'])
        plusCNmean = np.mean(readCountsDict[region]['plusCN'])
        minusSNmean = np.mean(readCountsDict[region]['minusSN'])
        minusCNmean = np.mean(readCountsDict[region]['minusCN'])
        if plusSNmean == 0:
            pass
        elif math.fabs((plusSNmean - readCountsDict[region]['muSplus'])/plusSNmean) > 0.05:
            pass
        else:
            if readCountsDict[region]['alphaSplus_bse'] != 'nan':
                muEstimatesSTARRList.append(readCountsDict[region]['muSplus'])
                alphaEstimatesSTARRList.append(readCountsDict[region]['alphaSplus'])
        if plusCNmean == 0:
            pass
        elif len(ControlFieldIDs) == 1:
            pass
        elif math.fabs((plusCNmean - readCountsDict[region]['muCplus'])/plusCNmean) > 0.05:
            pass
        else:
            if readCountsDict[region]['alphaCplus_bse'] != 'nan':
                muEstimatesControlList.append(readCountsDict[region]['muCplus'])
                alphaEstimatesControlList.append(readCountsDict[region]['alphaCplus'])
        if minusSNmean == 0:
            pass
        elif math.fabs((minusSNmean - readCountsDict[region]['muSminus'])/minusSNmean) > 0.05:
            pass
        else:
            if readCountsDict[region]['alphaSminus_bse'] != 'nan':
                muEstimatesSTARRList.append(readCountsDict[region]['muSminus'])
                alphaEstimatesSTARRList.append(readCountsDict[region]['alphaSminus'])
        if minusCNmean == 0:
            pass
        elif len(ControlFieldIDs) == 1:
            pass
        elif math.fabs((minusCNmean - readCountsDict[region]['muCminus'])/minusCNmean) > 0.05:
            pass
        else:
            if readCountsDict[region]['alphaCminus_bse'] != 'nan':
                muEstimatesControlList.append(readCountsDict[region]['muCminus'])
                alphaEstimatesControlList.append(readCountsDict[region]['alphaCminus'])

    DecilesDictSTARR = {}
    for i in range(1,dispBins+1):
        DecilesDictSTARR[i] = {}
        DecilesDictSTARR[i]['list'] = []

    muEstimatesSTARRList2 = []
    for mu in muEstimatesSTARRList:
        if str(mu) != 'nan':
            muEstimatesSTARRList2.append(round(mu,1))

    muEstimatesSTARRSet = list(Set(muEstimatesSTARRList2))
    muEstimatesSTARRSet.sort()

    print len(muEstimatesSTARRList2), len(muEstimatesSTARRSet), muEstimatesSTARRSet

    incrS = len(muEstimatesSTARRSet)/(dispBins + 0.0)
    for i in range(len(muEstimatesSTARRList)):
        mu = muEstimatesSTARRList[i]
        alpha = alphaEstimatesSTARRList[i]
        for j in range(1,dispBins+1):
            if mu >= muEstimatesSTARRSet[int(round((j-1)*incrS))] and mu < muEstimatesSTARRSet[min(int(round((j)*incrS)),len(muEstimatesSTARRSet)-1)]:
                DecilesDictSTARR[j]['list'].append(alpha)
                break

    print 'STARR dispersion parameters:'

    for i in range(1,dispBins+1):
        print i, muEstimatesSTARRSet[int(round((i-1)*incrS))], len(DecilesDictSTARR[i]['list']), np.mean(DecilesDictSTARR[i]['list']), np.std(DecilesDictSTARR[i]['list']), np.median(DecilesDictSTARR[i]['list']), min(DecilesDictSTARR[i]['list']), max(DecilesDictSTARR[i]['list'])

    if len(ControlFieldIDs) == 1:
        pass
    else:
        DecilesDictControl = {}
        for i in range(1,dispBins+1):
            DecilesDictControl[i] = {}
            DecilesDictControl[i]['list'] = []

        muEstimatesControlList2 = []
        for mu in muEstimatesControlList:
            if str(mu) != 'nan':
                muEstimatesControlList2.append(round(mu,1))

        muEstimatesControlSet = list(Set(muEstimatesControlList))
        muEstimatesControlSet.sort()
        incrC = len(muEstimatesControlSet)/(dispBins + 0.0)

        for i in range(len(muEstimatesControlList)):
            mu = muEstimatesControlList[i]
            alpha = alphaEstimatesControlList[i]
            for j in range(1,dispBins+1):
                if mu >= muEstimatesControlSet[int(round((j-1)*incrC))] and mu < muEstimatesControlSet[min(int(round((j)*incrC)),len(muEstimatesControlSet)-1)]:
                    DecilesDictControl[j]['list'].append(alpha)
                    break

        print 'Control disperson parameters:'
        for i in range(1,dispBins+1):
            print i, muEstimatesControlet[int(round((i-1)*incrS))], len(DecilesDictControl[i]['list']), np.mean(DecilesDictControl[i]['list']), np.std(DecilesDictControl[i]['list']), np.median(DecilesDictControl[i]['list']), min(DecilesDictControl[i]['list']), max(DecilesDictControl[i]['list'])

    for region in readCountsDict.keys():
        plusSNmean = np.mean(readCountsDict[region]['plusSN'])
        plusCNmean = np.mean(readCountsDict[region]['plusCN'])
        minusSNmean = np.mean(readCountsDict[region]['minusSN'])
        minusCNmean = np.mean(readCountsDict[region]['minusCN'])

        plusSandCMean = np.mean([plusSNmean,plusCNmean])
        minusSandCMean = np.mean([minusSNmean,minusCNmean])

        if plusSandCMean <= min(muEstimatesSTARRSet):
            decilealphaSTARR = np.mean(DecilesDictSTARR[1]['list'])
            decilealphaSTARR_std = np.std(DecilesDictSTARR[1]['list'])
        elif plusSandCMean <= max(muEstimatesSTARRSet):
            decilealphaSTARR = np.mean(DecilesDictSTARR[dispBins]['list'])
            decilealphaSTARR_std = np.std(DecilesDictSTARR[dispBins]['list'])
        else:
            for j in range(1,dispBins+1):
                if plusSandCMean >= muEstimatesSTARRSet[int(round((j-1)*incrS))] and plusSandCMean < muEstimatesSTARRSet[min(int(round((j)*incrS)),len(muEstimatesSTARRSet)-1)]:
                    decilealphaSTARR = np.mean(DecilesDictSTARR[j]['list'])
                    decilealphaSTARR_std = np.std(DecilesDictSTARR[j]['list'])
                    break
        if len(ControlFieldIDs) == 1:
             if plusSandCMean <= min(muEstimatesSTARRSet):
                 decilealphaControl = np.mean(DecilesDictSTARR[1]['list'])
                 decilealphaControl_std = np.std(DecilesDictSTARR[1]['list'])
             elif plusSandCMean >= max(muEstimatesSTARRSet):
                 decilealphaControl = np.mean(DecilesDictSTARR[dispBins]['list'])
                 decilealphaControl_std = np.std(DecilesDictSTARR[dispBins]['list'])
             else:
                 for j in range(1,dispBins+1):
                     if plusSandCMean >= muEstimatesSTARRSet[int(round((j-1)*incrS))] and plusSandCMean < muEstimatesSTARRSet[min(int(round((j)*incrS)),len(muEstimatesSTARRSet)-1)]:
                         decilealphaControl = np.mean(DecilesDictSTARR[j]['list'])
                         decilealphaControl_std = np.std(DecilesDictSTARR[j]['list'])
                         break
        else:
             if plusSandCMean <= min(muEstimatesControlSet):
                 decilealphaControl = np.mean(DecilesDictControl[1]['list'])
                 decilealphaControl_std = np.std(DecilesDictControl[1]['list'])
             elif plusSandCMean >= max(muEstimatesControlSet):
                 decilealphaControl = np.mean(DecilesDictControl[dispBins]['list'])
                 decilealphaControl_std = np.std(DecilesDictControl[dispBins]['list'])
             else:
                 for j in range(1,dispBins+1):
                     if plusSandCMean >= muEstimatesControlSet[int(round((j-1)*incrC))] and plusSandCMean < muEstimatesControlSet[min(int(round((j)*incrC)),len(muEstimatesControlSet)-1)]:
                         decilealphaControl = np.mean(DecilesDictControl[j]['list'])
                         decilealphaControl_std = np.std(DecilesDictControl[j]['list'])
                         break

        readCountsDict[region]['alphaSandCplus_decile_STARR'] = BayesianUpdateMu(readCountsDict[region]['alphaSplus'],decilealphaSTARR,decilealphaSTARR_std)
        if len(ControlFieldIDs) == 1:
            readCountsDict[region]['alphaSandCplus_decile_Control'] = decilealphaControl
        else:
            readCountsDict[region]['alphaSandCplus_decile_Control'] = BayesianUpdateMu(readCountsDict[region]['alphaCplus'],decilealphaControl,decilealphaControl_std)
        readCountsDict[region]['plusSandCMean'] = plusSandCMean

#        logmean = np.log2(minusSandCMean)
        if minusSandCMean <= min(muEstimatesSTARRSet):
            decilealphaSTARR = np.mean(DecilesDictSTARR[1]['list'])
            decilealphaSTARR_std = np.std(DecilesDictSTARR[1]['list'])
        elif minusSandCMean >= max(muEstimatesSTARRSet):
            decilealphaSTARR = np.mean(DecilesDictSTARR[dispBins]['list'])
            decilealphaSTARR_std = np.std(DecilesDictSTARR[dispBins]['list'])
        else:
            for j in range(1,dispBins+1):
                if minusSandCMean >= muEstimatesSTARRSet[int(round((j-1)*incrS))] and minusSandCMean < muEstimatesSTARRSet[min(int(round((j)*incrS)),len(muEstimatesSTARRSet)-1)]:
                    decilealphaSTARR = np.mean(DecilesDictSTARR[j]['list'])
                    decilealphaSTARR_std = np.std(DecilesDictSTARR[j]['list'])
                    break
#        logmean = np.log2(minusSandCMean)
        if len(ControlFieldIDs) == 1:
            if minusSandCMean <= min(muEstimatesSTARRSet):
                decilealphaControl = np.mean(DecilesDictSTARR[1]['list'])
#                decilealphaControl_std = np.std(DecilesDictSTARR[1]['list'])
            elif minusSandCMean >= max(muEstimatesSTARRSet):
                decilealphaControl = np.mean(DecilesDictSTARR[dispBins]['list'])
#                decilealphaControl_std = np.std(DecilesDictSTARR[dispBins]['list'])
            else:
                for j in range(1,dispBins+1):
                    if minusSandCMean >= muEstimatesSTARRSet[int(round((j-1)*incrS))] and minusSandCMean < muEstimatesSTARRSet[min(int(round((j)*incrS)),len(muEstimatesSTARRSet)-1)]:
                        decilealphaControl = np.mean(DecilesDictSTARR[j]['list'])
#                        decilealphaControl_std = np.std(DecilesDictSTARR[j]['list'])
                        break
        else:
            if minusSandCMean <= min(muEstimatesControlSet):
                decilealphaControl = np.mean(DecilesDictControl[1]['list'])
                decilealphaControl_std = np.std(DecilesDictControl[1]['list'])
            elif minusSandCMean >= max(muEstimatesControlSet):
                decilealphaControl = np.mean(DecilesDictControl[dispBins]['list'])
                decilealphaControl_std = np.std(DecilesDictControl[dispBins]['list'])
            else:
                for j in range(1,dispBins+1):
                    if minusSandCMean >= muEstimatesControlSet[int(round((j-1)*incrC))] and minusSandCMean < muEstimatesControlSet[min(int(round((j)*incrC)),len(muEstimatesControlSet)-1)]:
                        decilealphaControl = np.mean(DecilesDictControl[j]['list'])
                        decilealphaControl_std = np.std(DecilesDictControl[j]['list'])
                        break

        readCountsDict[region]['alphaSandCminus_decile_STARR'] = BayesianUpdateMu(readCountsDict[region]['alphaSminus'],decilealphaSTARR,decilealphaSTARR_std)
        if len(ControlFieldIDs) == 1:
            readCountsDict[region]['alphaSandCminus_decile_Control'] = decilealphaControl
        else:
            readCountsDict[region]['alphaSandCminus_decile_Control'] = BayesianUpdateMu(readCountsDict[region]['alphaCminus'],decilealphaControl,decilealphaControl_std)
        readCountsDict[region]['minusSandCMean'] = minusSandCMean

        if plusSNmean <= min(muEstimatesSTARRSet):
            decilealpha = np.mean(DecilesDictSTARR[1]['list'])
            decilealpha_std = np.std(DecilesDictSTARR[1]['list'])
        elif plusSNmean >= max(muEstimatesSTARRSet):
            decilealpha = np.mean(DecilesDictSTARR[dispBins]['list'])
            decilealpha_std = np.std(DecilesDictSTARR[dispBins]['list'])
        else:
            for j in range(1,dispBins+1):
                if plusSNmean > muEstimatesSTARRSet[int(round((j-1)*incrS))] and plusSNmean <= muEstimatesSTARRSet[min(int(round((j)*incrS)),len(muEstimatesSTARRSet)-1)]:
                    decilealpha = np.mean(DecilesDictSTARR[j]['list'])
                    decilealpha_std = np.std(DecilesDictSTARR[j]['list'])
                    break
        readCountsDict[region]['alphaSplus_decile'] = BayesianUpdateMu(readCountsDict[region]['alphaSplus'],decilealpha,decilealpha_std)

        if len(ControlFieldIDs) == 1:
            if plusCNmean <= min(muEstimatesSTARRSet):
                decilealpha = np.mean(DecilesDictSTARR[1]['list'])
#                decilealpha_std = np.std(DecilesDictSTARR[1]['list'])
            elif plusCNmean >= max(muEstimatesSTARRSet):
                decilealpha = np.mean(DecilesDictSTARR[dispBins]['list'])
#                decilealpha_std = np.std(DecilesDictSTARR[dispBins]['list'])
            else:
                for j in range(1,dispBins+1):
                    if plusCNmean > muEstimatesSTARRSet[int(round((j-1)*incrS))] and plusCNmean <= muEstimatesSTARRSet[min(int(round((j)*incrS)),len(muEstimatesSTARRSet)-1)]:
                        decilealpha = np.mean(DecilesDictSTARR[j]['list'])
#                        decilealpha_std = np.std(DecilesDictSTARR[j]['list'])
                        break
            readCountsDict[region]['alphaCplus_decile'] = decilealpha
        else:
            if plusCNmean <= min(muEstimatesControlSet):
                decilealpha = np.mean(DecilesDictControl[1]['list'])
                decilealpha_std = np.std(DecilesDictControl[1]['list'])
            elif plusCNmean >= max(muEstimatesControlSet):
                decilealpha = np.mean(DecilesDictControl[dispBins]['list'])
                decilealpha_std = np.std(DecilesDictControl[dispBins]['list'])
            else:
                for j in range(1,dispBins+1):
                    if plusCNmean > muEstimatesControlSet[int(round((j-1)*incrC))] and plusCNmean <= muEstimatesControlSet[min(int(round((j)*incrC)),len(muEstimatesControlSet)-1)]:
                        decilealpha = np.mean(DecilesDictControl[j]['list'])
                        decilealpha_std = np.std(DecilesDictControl[j]['list'])
                        break
            readCountsDict[region]['alphaCplus_decile'] = BayesianUpdateMu(readCountsDict[region]['alphaCplus'],decilealpha,decilealpha_std)

        if minusSNmean <= min(muEstimatesSTARRSet):
            decilealpha = np.mean(DecilesDictSTARR[1]['list'])
            decilealpha_std = np.std(DecilesDictSTARR[1]['list'])
        elif minusSNmean >= max(muEstimatesSTARRSet):
            decilealpha = np.mean(DecilesDictSTARR[dispBins]['list'])
            decilealpha_std = np.std(DecilesDictSTARR[dispBins]['list'])
        else:
            for j in range(1,dispBins+1):
                if minusSNmean > muEstimatesSTARRSet[int(round((j-1)*incrS))] and minusSNmean <= muEstimatesSTARRSet[min(int(round((j)*incrS)),len(muEstimatesSTARRSet)-1)]:
                    decilealpha = np.mean(DecilesDictSTARR[j]['list'])
                    decilealpha_std = np.std(DecilesDictSTARR[j]['list'])
                    break
        readCountsDict[region]['alphaSminus_decile'] = BayesianUpdateMu(readCountsDict[region]['alphaSminus'],decilealpha,decilealpha_std)

        if len(ControlFieldIDs) == 1:
            if minusCNmean <= min(muEstimatesSTARRSet):
                decilealpha = np.mean(DecilesDictSTARR[1]['list'])
#                decilealpha_std = np.std(DecilesDictSTARR[1]['list'])
            elif minusCNmean >= max(muEstimatesSTARRSet):
                decilealpha = np.mean(DecilesDictSTARR[dispBins]['list'])
#                decilealpha_std = np.std(DecilesDictSTARR[dispBins]['list'])
            else:
                for j in range(1,dispBins+1):
                    if minusCNmean > muEstimatesSTARRSet[int(round((j-1)*incrS))] and minusCNmean <= muEstimatesSTARRSet[min(int(round((j)*incrS)),len(muEstimatesSTARRSet)-1)]:
                        decilealpha = np.mean(DecilesDictSTARR[j]['list'])
#                        decilealpha_std = np.std(DecilesDictSTARR[j]['list'])
                        break
            readCountsDict[region]['alphaCminus_decile'] = decilealpha
        else:
            if minusCNmean <= min(muEstimatesControlSet):
                decilealpha = np.mean(DecilesDictControl[1]['list'])
                decilealpha_std = np.std(DecilesDictControl[1]['list'])
            elif minusCNmean >= max(muEstimatesControlSet):
                decilealpha = np.mean(DecilesDictControl[dispBins]['list'])
                decilealpha_std = np.std(DecilesDictControl[dispBins]['list'])
            else:
                for j in range(1,dispBins+1):
                    if minusCNmean > muEstimatesControlSet[int(round((j-1)*incrC))] and minusCNmean <= muEstimatesControlSet[min(int(round((j)*incrC)),len(muEstimatesControlSet)-1)]:
                        decilealpha = np.mean(DecilesDictControl[j]['list'])
                        decilealpha_std = np.std(DecilesDictControl[j]['list'])
                        break
            readCountsDict[region]['alphaCminus_decile'] = BayesianUpdateMu(readCountsDict[region]['alphaCminus'],decilealpha,decilealpha_std)

    outfile = open(outprefix + '.params','w')

    outline = '#region'
    for ID in ControlFieldIDs: 
        outline = outline + '\t' + str(ID)
    for ID in STARRFieldIDs:
        outline = outline + '\t' + str(ID)
    outline =  outline + '\tmuControl\talphaControl\tmuSTARR\talphaSTARR\talphaControl_moderated\talphaSTARR_moderated'
    outfile.write(outline + '\n')

    regionssss = readCountsDict.keys()
    regionssss.sort()
    for region in regionssss:
        outline = region + '|+'
        i = 0
        for ID in STARRFieldIDs:
            outline = outline + '\t' + str(readCountsDict[region]['plusSN'][i])
            i+=1
        i = 0
        for ID in ControlFieldIDs: 
            outline = outline + '\t' + str(readCountsDict[region]['plusCN'][i])
            i+=1
        outline = outline + '\t' + str(readCountsDict[region]['muCplus'])
        outline = outline + '\t' + str(readCountsDict[region]['alphaCplus'])
        outline = outline + '\t' + str(readCountsDict[region]['muSplus'])
        outline = outline + '\t' + str(readCountsDict[region]['alphaSplus'])
        outline = outline + '\t' + str(readCountsDict[region]['alphaCplus_decile'])
        outline = outline + '\t' + str(readCountsDict[region]['alphaSplus_decile'])
        outfile.write(outline + '\n')
        outline = region + '|-'
        i = 0
        for ID in STARRFieldIDs:
            outline = outline + '\t' + str(readCountsDict[region]['minusSN'][i])
            i+=1
        i = 0
        for ID in ControlFieldIDs: 
            outline = outline + '\t' + str(readCountsDict[region]['minusCN'][i])
            i+=1
        outline = outline + '\t' + str(readCountsDict[region]['muCminus'])
        outline = outline + '\t' + str(readCountsDict[region]['alphaCminus'])
        outline = outline + '\t' + str(readCountsDict[region]['muSminus'])
        outline = outline + '\t' + str(readCountsDict[region]['alphaSminus'])
        outline = outline + '\t' + str(readCountsDict[region]['alphaCminus_decile'])
        outline = outline + '\t' + str(readCountsDict[region]['alphaSminus_decile'])
        outfile.write(outline + '\n')

    outfile.close()

    print 'sampling NB distribution and estimating significant changes'

    NRCArray2 = []
    regionlist = readCountsDict.keys()
    regionlist.sort()
    k = len(regionlist)/NP

    j=0
    for i in range(NP):
        NRCDict = {}
        PvalueDict = {}
        if i+1 == NP:
            while j < len(regionlist):
                region = regionlist[j]
                NRCDict[region] = readCountsDict[region]
                j += 1
        else:
            while j < k*(i+1):
                region = regionlist[j]
                NRCDict[region] = readCountsDict[region]
                j += 1
        NRCArray2.append((NRCDict,PvalueDict,MFC,NSamplings))

    p = Pool(NP)
    PvalueDictDicts = p.map(sampling_p_values, NRCArray2)

    outfile = open(outprefix + '.final_results','w')

    outline = '#region'
    for ID in STARRFieldIDs:
        outline = outline + '\t' + str(ID) + '|S|+'
    for ID in ControlFieldIDs:
        outline = outline + '\t' + str(ID) + '|C|+'
    for ID in STARRFieldIDs:
        outline = outline + '\t' + str(ID) + '|S|-'
    for ID in ControlFieldIDs:
        outline = outline + '\t' + str(ID) + '|C|-'
    outline =  outline + '\tplus_logFC\tp_val\tminus_logFC\tp_val\tabs_plusminus_logFC\tp_val'
    outfile.write(outline + '\n')

    final_regions = []
    for PP in PvalueDictDicts:
        for region in PP.keys():
            (p_plus,p_minus,pm,PlusMinusRatio,PlusFC,MinusFC) = PP[region]
            final_regions.append((region,p_plus,p_minus,pm,PlusMinusRatio,PlusFC,MinusFC))
    final_regions.sort()
    for (region,p_plus,p_minus,pm,PlusMinusRatio,PlusFC,MinusFC) in final_regions:
        outline = region
        i = 0
        for ID in STARRFieldIDs:
            outline = outline + '\t' + str(readCountsDict[region]['plusSN'][i])
            i+=1
        i = 0
        for ID in ControlFieldIDs: 
            outline = outline + '\t' + str(readCountsDict[region]['plusCN'][i])
            i+=1
        i = 0
        for ID in STARRFieldIDs:
            outline = outline + '\t' + str(readCountsDict[region]['minusSN'][i])
            i+=1
        i = 0
        for ID in ControlFieldIDs: 
            outline = outline + '\t' + str(readCountsDict[region]['minusCN'][i])
            i+=1
        outline = outline + '\t' + str(np.log2(PlusFC))
        outline = outline + '\t' + str(p_plus)
        outline = outline + '\t' + str(np.log2(MinusFC))
        outline = outline + '\t' + str(p_minus)
        if PlusMinusRatio >= 1:
            outline = outline + '\t' + str(np.log2(PlusMinusRatio))
        if PlusMinusRatio < 1:
            outline = outline + '\t' + str(np.log2(1/PlusMinusRatio))
        outline = outline + '\t' + str(pm)
        outfile.write(outline + '\n')

    outfile.close()

# Start the pipeline only when this file is executed as a script.
# The code above hands work to multiprocessing.Pool; on platforms where
# worker processes are started by re-importing the main module, an
# unguarded top-level call would be executed again inside every worker
# (and on any plain `import` of this file).
if __name__ == '__main__':
    run()

