##################################
#                                #
# Last modified 2017/04/16       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import gc
import math
import string
from sets import Set
import os
from multiprocessing import Pool
from threading import Thread
import numpy as np
from patsy import dmatrices, dmatrix, demo_data
from scipy.stats import nbinom
from statsmodels.base.model import GenericLikelihoodModel
import statsmodels.api as sm

def _ll_nb2(y, X, beta, alph):
    mu = np.exp(np.dot(X, beta))
    size = 1/alph
    prob = size/(size+mu)
    ll = nbinom.logpmf(y, size, prob)
    return ll

class NBin(GenericLikelihoodModel):
    """Negative binomial (NB2) regression fitted by maximum likelihood.

    The parameter vector is the regression coefficients followed by the
    NB2 dispersion parameter, exposed as 'alpha' in the fit summary.
    """

    def __init__(self, endog, exog, **kwds):
        super(NBin, self).__init__(endog, exog, **kwds)

    def nloglikeobs(self, params):
        """Per-observation negative log-likelihood; params[-1] is alpha."""
        alph = params[-1]
        beta = params[:-1]
        ll = _ll_nb2(self.endog, self.exog, beta, alph)
        return -ll

    def fit(self, start_params=None, maxiter=10000, maxfun=5000, **kwds):
        # We have one additional parameter (alpha) beyond the design matrix
        # and we need to add it for the summary -- but only once, so that
        # repeated fit() calls do not keep appending duplicate names.
        if 'alpha' not in self.exog_names:
            self.exog_names.append('alpha')
        # BUG FIX: the original tested `start_params == None`, which raises
        # "truth value of an array is ambiguous" when a caller supplies a
        # numpy array of starting values; identity test is the correct form.
        if start_params is None:
            # Reasonable starting values: zero coefficients, alpha = 0.5,
            # intercept initialized at the log of the mean response.
            start_params = np.append(np.zeros(self.exog.shape[1]), .5)
            # intercept
            start_params[-2] = np.log(self.endog.mean())
        return super(NBin, self).fit(start_params=start_params,
                                     maxiter=maxiter, maxfun=maxfun, **kwds)

def estimate_dispersion((NRCDict,DispersionDict)):
    """Fit a per-fragment NB2 model of normalized counts ~ sample class.

    Python 2 tuple-unpacked argument (so it can be Pool.map'ed):
      NRCDict        : fragment -> dict with 'CN' (control) and 'SN' (STARR)
                       lists of normalized integer counts
      DispersionDict : dict to fill; fragment -> (logratio, logratioSE, alpha)

    Returns the populated DispersionDict.
    """

    FFFF = 0

    for fragment in NRCDict.keys():
        FFFF += 1
        if FFFF % 1000 == 0:
            print FFFF, 'fragments processed'
        # CC: response counts; II: class labels ('C' control, 'S' STARR).
        CC = []
        II = []
        for normCounts in NRCDict[fragment]['CN']:
            # Pseudo-count guard: if ALL control counts are zero, replace the
            # first one with 1 so the regression does not degenerate.
            # NOTE(review): the condition re-checks max() on every iteration
            # but len(CC)==0 restricts the substitution to the first entry.
            if len(CC) == 0 and max(NRCDict[fragment]['CN']) == 0:
                CC.append(1)
            else:
                CC.append(normCounts)
            II.append('C')
        for normCounts in NRCDict[fragment]['SN']:
            # Same guard for the STARR side: len(CC)==len(CN) holds only for
            # the first STARR entry.
            if len(CC) == len(NRCDict[fragment]['CN']) and max(NRCDict[fragment]['SN']) == 0:
                CC.append(1)
            else:
                CC.append(normCounts)
            II.append('S')
        # NOTE(review): II holds only the strings 'C'/'S'; in Python 2,
        # max(II) == 0 is always False (str > int), so this branch appears
        # to be dead code -- confirm original intent before changing.
        if max(II) == 0:
            II[0] = 1
#        print fragment, NRCDict[fragment], CC, II
        # Build the design via patsy and fit the NB2 model defined above.
        NBdata = {'c': CC, 'i': II}
        c, i = dmatrices("c ~ i", NBdata)
        mod = NBin(c, i)
        res = mod.fit(disp=0,iprint=0)
#        except:
#            print fragment, NRCDict[fragment]
#            print 'EXITING BECAUSE OF NUMERICAL ISSUES'
#            sys.exit(1)
        # params: [intercept, class log-ratio, alpha]; bse may be unavailable
        # when the Hessian inversion fails, hence the broad fallback to 'nan'.
        logratio = res.params[1]
        alpha = res.params[2]
        try:
            logratioSE = res.bse[1]
        except:
            logratioSE = 'nan'
        DispersionDict[fragment] = (logratio,logratioSE,alpha)
#        print res.summary()

    return DispersionDict

def run():
    """Command-line driver.

    Reads a fragment read-count table, normalizes counts across BAM samples
    (downsizing to lowest depth, or median-of-ratios), fits a per-fragment
    NB2 dispersion model in NP worker processes, and writes
    <outprefix>.first_pass_params.
    """

    if len(sys.argv) < 5:
        print 'usage: python %s RNA_bam1,RNA_bam2,...,RNA_bamN Input_bam1,Input_bam2,...,Input_bamN readCounts.table outprefix [-singleFieldCoords] [-norm median-of-ratios] [-p threads]' % sys.argv[0]
        print '\t[-norm median-of-ratios] option: use the median-of-ratios-method for read count normalization; default: downsizing to the lowest sequencing depth'
        sys.exit(1)

    # Positional arguments: comma-separated STARR/control BAM labels must
    # match the column headers of the readCounts table.
    STARRBAMFiles = sys.argv[1].split(',')
    ControlBAMFiles = sys.argv[2].split(',')
    readCounts = sys.argv[3]
    outprefix = sys.argv[4]

    # -singleFieldCoords: coordinates are encoded as "chr:left-right|strand"
    # in a single first column instead of four separate columns.
    doSFC = False
    if '-singleFieldCoords' in sys.argv:
        doSFC = True

    # doLDS (default): downsize to the lowest sequencing depth;
    # otherwise use DESeq-style median-of-ratios normalization.
    doLDS = True
    if '-norm' in sys.argv:
        if sys.argv[sys.argv.index('-norm') + 1] == 'median-of-ratios':
            doLDS = False

    NP = 1
    if '-p' in sys.argv:
        NP = int(sys.argv[sys.argv.index('-p') + 1])
        print 'will run on', NP, 'threads'

    # Per-BAM total read counts over all fragments.
    TotalReadCountDict = {}
    for BAM in STARRBAMFiles:
        TotalReadCountDict[BAM] = 0
    for BAM in ControlBAMFiles:
        TotalReadCountDict[BAM] = 0

    # Parse the read-count table. Header line(s) (starting with '#') map each
    # BAM label to its column index.
    # NOTE(review): FieldIDDict is only defined once a '#' header line has
    # been seen -- a table without a header would raise NameError below.
    readCountsDict = {}
    linelist = open(readCounts)
    k = 0
    for line in linelist:
        k+=1
        if k % 5000000 == 0:
            print str(k/1000000) + 'M fragments parsed'
        if line.startswith('#'):
            FieldIDDict = {}
            fields = line.strip().split('\t')
            if doSFC:
                # Single coordinate field: BAM columns start at index 1.
                for i in range(1,len(fields)):
                    ID = fields[i]
                    FieldIDDict[ID] = i
            else:
                # chr/left/right/strand columns: BAM columns start at index 4.
                for i in range(4,len(fields)):
                    ID = fields[i]
                    FieldIDDict[ID] = i
            continue
        fields = line.strip().split('\t')
        # NOTE: 'chr' shadows the builtin; kept as-is for byte-identity.
        if doSFC:
            chr = fields[0].split(':')[0]
            left = int(fields[0].split(':')[1].split('|')[0].split('-')[0])
            right = int(fields[0].split(':')[1].split('|')[0].split('-')[1])
            strand = fields[0].split(':')[1].split('|')[1]
        else:
            chr = fields[0]
            left = int(fields[1])
            right = int(fields[2])
            strand = fields[3]
        fragment = (chr,left,right,strand)
        # 'S'/'C': raw STARR/control counts in BAM-list order.
        readCountsDict[fragment] = {}
        readCountsDict[fragment]['S'] = []
        readCountsDict[fragment]['C'] = []
        for BAM in STARRBAMFiles:
            readCountsDict[fragment]['S'].append(int(fields[FieldIDDict[BAM]]))
            TotalReadCountDict[BAM] += int(fields[FieldIDDict[BAM]])
        for BAM in ControlBAMFiles:
            readCountsDict[fragment]['C'].append(int(fields[FieldIDDict[BAM]]))
            TotalReadCountDict[BAM] += int(fields[FieldIDDict[BAM]])

    print len(readCountsDict), 'fragments found'

    TRClist = []
    for BAM in TotalReadCountDict.keys():
        TRClist.append(TotalReadCountDict[BAM])
#        print BAM, TotalReadCountDict[BAM], min(TRClist)
    
    MinimumTotalReadCounts = min(TRClist)

    # Downsizing ratios: each sample's depth relative to the shallowest one
    # (the "+ 0.0" forces float division under Python 2).
    LDSratioDict = {}
    for BAM in TotalReadCountDict.keys():
        LDSratioDict[BAM] = TotalReadCountDict[BAM]/(MinimumTotalReadCounts + 0.0)
#        print BAM, TotalReadCountDict[BAM], LDSratioDict[BAM]

    if doLDS:
        # Normalization by downsizing: divide each count by the sample's
        # depth ratio and round to an integer ('SN'/'CN' lists).
        k = 0
        for fragment in readCountsDict.keys():
            k+=1
            if k % 5000000 == 0:
                print str(k/1000000) + 'M fragments processed in read count normalization'
            readCountsDict[fragment]['SN'] = []
            readCountsDict[fragment]['CN'] = []
            i = 0
            for BAM in STARRBAMFiles:
                counts = readCountsDict[fragment]['S'][i]
                normCounts = counts/LDSratioDict[BAM]
                readCountsDict[fragment]['SN'].append(int(round(normCounts)))
                i += 1
            i = 0
            for BAM in ControlBAMFiles:
                counts = readCountsDict[fragment]['C'][i]
                normCounts = counts/LDSratioDict[BAM]
                readCountsDict[fragment]['CN'].append(int(round(normCounts)))
                i += 1
    else:
        # Median-of-ratios normalization (DESeq-style): for each fragment,
        # compute the geometric mean across samples; each sample's size
        # factor is the median of counts/geometric-mean over all fragments
        # with nonzero counts in every sample.
        MedianRatiosDict = {}
        for BAM in TotalReadCountDict.keys():
            MedianRatiosDict[BAM] = []
        k = 0
        for fragment in readCountsDict.keys():
            k+=1
            if k % 5000000 == 0:
                print str(k/1000000) + 'M fragments processed in read count normalization'
            i = 0
            n = 0
            Prod = 1
            for BAM in STARRBAMFiles:
                counts = readCountsDict[fragment]['S'][i]
                Prod = Prod*counts
                i += 1
                n += 1
            i = 0
            for BAM in ControlBAMFiles:
                counts = readCountsDict[fragment]['C'][i]
                Prod = Prod*counts
                i += 1
                n += 1
            # Fragments with any zero count are skipped (geometric mean 0).
            if Prod == 0:
                continue
            else:
                ProdNthroot = math.pow(Prod,(1./n))
            i = 0
            for BAM in STARRBAMFiles:
                counts = readCountsDict[fragment]['S'][i]
                i += 1
                MedianRatiosDict[BAM].append(counts/ProdNthroot)
            i = 0
            for BAM in ControlBAMFiles:
                counts = readCountsDict[fragment]['C'][i]
                MedianRatiosDict[BAM].append(counts/ProdNthroot)
                i += 1
        MMratioDict = {}
        for BAM in TotalReadCountDict.keys():
#            print BAM, len(MedianRatiosDict[BAM])
            MMratioDict[BAM] = np.median(np.array(MedianRatiosDict[BAM]))
#            print BAM, TotalReadCountDict[BAM], MMratioDict[BAM]
        k = 0
        for fragment in readCountsDict.keys():
            k+=1
            if k % 5000000 == 0:
                print str(k/1000000) + 'M fragments processed in read count normalization'
            readCountsDict[fragment]['SN'] = []
            readCountsDict[fragment]['CN'] = []
            i = 0
            for BAM in STARRBAMFiles:
                counts = readCountsDict[fragment]['S'][i]
                normCounts = counts/MMratioDict[BAM]
                readCountsDict[fragment]['SN'].append(int(round(normCounts)))
                i += 1
            i = 0
            for BAM in ControlBAMFiles:
                counts = readCountsDict[fragment]['C'][i]
                normCounts = counts/MMratioDict[BAM]
                readCountsDict[fragment]['CN'].append(int(round(normCounts)))
                i += 1
            
    print 'finished read count normalization'

    print 'estimating dispersion'

    # Partition the sorted fragment list into NP contiguous chunks, one per
    # worker process; the last chunk absorbs the remainder.
    # NOTE: len(fragmentlist)/NP is Python 2 integer division (intended).
    NRCArray = []
    fragmentlist = readCountsDict.keys()
    fragmentlist.sort()
    k = len(fragmentlist)/NP

    j=0
    for i in range(NP):
        NRCDict = {}
        DispersionDict = {}
        if i+1 == NP:
            while j < len(fragmentlist):
                fragment = fragmentlist[j]
                NRCDict[fragment] = readCountsDict[fragment]
                j += 1
        else:
            while j < k*(i+1):
                fragment = fragmentlist[j]
                NRCDict[fragment] = readCountsDict[fragment]
                j += 1
        NRCArray.append((NRCDict,DispersionDict))

    p = Pool(NP)
    DispersionDicts = p.map(estimate_dispersion, NRCArray)

    # NOTE(review): this DispersionDict is never read again -- leftover?
    DispersionDict = {}

    outfile = open(outprefix + '.first_pass_params','w')

    # Header: fragment id, one normalized-count column per BAM, then the
    # per-fragment means and model parameters.
    outline = '#fragment'
    for BAM in STARRBAMFiles:
        outline = outline + '\t' + BAM
    for BAM in ControlBAMFiles:
        outline = outline + '\t' + BAM
    outline =  outline + '\tmeanSTARR\tmeanControl\tmeanTotal\tlogratio\tlogratioSE\talpha'
    outfile.write(outline + '\n')

    for DD in DispersionDicts:
        for fragment in DD.keys():
            (logratio,logratioSE,alpha) = DD[fragment]
            outline = str(fragment[0]) + ':' + str(fragment[1]) + '-' + str(fragment[2]) + '|' + str(fragment[3])
            # mS/mC/mT: mean normalized STARR / control / overall counts.
            mS = 0.0
            mC = 0.0
            mT = 0.0
            i = 0
            for BAM in STARRBAMFiles:
                outline = outline + '\t' + str(readCountsDict[fragment]['SN'][i])
                mS += readCountsDict[fragment]['SN'][i]
                i += 1
            j = 0
            for BAM in ControlBAMFiles:
                outline = outline + '\t' + str(readCountsDict[fragment]['CN'][j])
                mC += readCountsDict[fragment]['CN'][j]
                j += 1
            mT = (mS + mC)/(i+j)
            mS = mS/i
            mC = mC/j
            outline = outline + '\t' + str(mS) + '\t' + str(mC) + '\t' + str(mT)
            outline = outline + '\t' + str(logratio) + '\t' + str(logratioSE) + '\t' + str(alpha)
            outfile.write(outline + '\n')

    outfile.close()

# Module entry point: executes immediately when the script is run (or
# imported -- there is deliberately no __main__ guard in this script).
run()
