##################################
#                                #
# Last modified 2018/06/25       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import gc
import math
import string
from sets import Set
import os
# from multiprocessing import Pool
# from threading import Thread
import numpy as np
# from patsy import dmatrices, dmatrix, demo_data
# from scipy.stats import nbinom
# from statsmodels import robust
# from statsmodels.discrete.discrete_model import NegativeBinomial
# import statsmodels.api as sm

def _usage_and_exit():
    """Print the command-line usage message and exit with status 1."""
    print('usage: python %s table STARRfields Controlfields outfilename [-norm median-of-ratios]' % sys.argv[0])
    print('\tNote: STARRfields Controlfields should be comma separated')
    sys.exit(1)


def _parse_field_ids(arg):
    """Turn a comma-separated string of column indices into a list of ints."""
    return [int(ID) for ID in arg.split(',')]


def _read_counts(readCounts, STARRFieldIDs, ControlFieldIDs):
    """Parse the tab-delimited read count table.

    Returns a tuple (headerLines, TotalReadCountDict, readCountsDict):
      headerLines        - one output header line per input line starting
                           with '#', restricted to the selected columns
      TotalReadCountDict - column index -> sum of that column over all
                           non-header rows (including 'complement' rows)
      readCountsDict     - region (first column) -> dict with the raw STARR
                           counts under 'S' and the raw control counts
                           under 'C'
    """
    allFieldIDs = STARRFieldIDs + ControlFieldIDs
    headerLines = []
    TotalReadCountDict = dict.fromkeys(allFieldIDs, 0)
    readCountsDict = {}

    with open(readCounts) as linelist:
        for line in linelist:
            stripped = line.strip()
            if not stripped:
                # blank lines carry no counts; skip them instead of crashing
                continue
            fields = stripped.split('\t')
            if line.startswith('#'):
                headerLines.append('\t'.join(['#'] + [fields[ID] for ID in allFieldIDs]))
                continue
            # Iterate over the unique column indices so that a column listed
            # more than once is only added to its total once per row.
            for ID in TotalReadCountDict:
                TotalReadCountDict[ID] += float(fields[ID])
            # 'complement' rows contribute to the totals but are not reported
            if fields[0].startswith('complement'):
                continue
            entry = readCountsDict.setdefault(fields[0], {'C': [], 'S': [], 'CN': [], 'SN': []})
            entry['S'].extend(float(fields[ID]) for ID in STARRFieldIDs)
            entry['C'].extend(float(fields[ID]) for ID in ControlFieldIDs)

    return headerLines, TotalReadCountDict, readCountsDict


def _lds_factors(TotalReadCountDict):
    """Return column index -> (column total / smallest column total)."""
    MinimumTotalReadCounts = min(TotalReadCountDict.values())
    if MinimumTotalReadCounts <= 0:
        # dividing by a zero total would raise ZeroDivisionError below
        sys.exit('error: at least one selected column has a total read count of zero; '
                 'cannot apply lds normalization')
    return dict((ID, total / float(MinimumTotalReadCounts))
                for ID, total in TotalReadCountDict.items())


def _median_of_ratios_factors(readCountsDict, STARRFieldIDs, ControlFieldIDs):
    """Return column index -> median over regions of count / geometric mean.

    For every region the geometric mean of its counts across all selected
    columns is computed; regions where any count is zero are left out.
    """
    allFieldIDs = STARRFieldIDs + ControlFieldIDs
    MedianRatiosDict = dict((ID, []) for ID in allFieldIDs)

    for entry in readCountsDict.values():
        # slice to the column count: a region that occurred more than once in
        # the input carries extra values, and only the first occurrence is used
        counts = entry['S'][:len(STARRFieldIDs)] + entry['C'][:len(ControlFieldIDs)]
        Prod = 1
        for value in counts:
            Prod = Prod * value
        if Prod == 0:
            continue
        ProdNthroot = math.pow(Prod, 1. / len(counts))
        for ID, value in zip(allFieldIDs, counts):
            MedianRatiosDict[ID].append(value / ProdNthroot)

    MMratioDict = {}
    for ID, ratios in MedianRatiosDict.items():
        if not ratios:
            # np.median of an empty list is nan, which cannot be rounded later
            sys.exit('error: no region has non-zero counts in all selected columns; '
                     'cannot apply the median-of-ratios normalization')
        MMratioDict[ID] = float(np.median(np.array(ratios)))
    return MMratioDict


def _apply_factors(readCountsDict, STARRFieldIDs, ControlFieldIDs, factors):
    """Store the rounded counts divided by their column factor under 'SN' and 'CN'."""
    for entry in readCountsDict.values():
        entry['SN'] = [int(round(counts / factors[ID]))
                       for ID, counts in zip(STARRFieldIDs, entry['S'])]
        entry['CN'] = [int(round(counts / factors[ID]))
                       for ID, counts in zip(ControlFieldIDs, entry['C'])]


def run():
    """Normalize the read counts of a table and write them to a new table.

    Command line (read from sys.argv):
        table STARRfields Controlfields outfilename [-norm median-of-ratios]

    table is a tab-delimited file whose first column is the region name;
    STARRfields and Controlfields are comma-separated 0-based column indices.
    By default every column is divided by (column total / smallest column
    total) -- the 'lds' normalization; with '-norm median-of-ratios' every
    column is divided by its median ratio to the per-region geometric mean.
    Normalized counts are rounded to integers and written, sorted by region
    name, STARR columns first and control columns second.

    Exits with a usage message when arguments are missing, and with an error
    message when the requested normalization cannot be computed.
    """
    # sys.argv[4] (outfilename) is mandatory, so five entries are required
    if len(sys.argv) < 5:
        _usage_and_exit()

    readCounts = sys.argv[1]
    STARRFieldIDs = _parse_field_ids(sys.argv[2])
    ControlFieldIDs = _parse_field_ids(sys.argv[3])
    outfilename = sys.argv[4]

    doLDS = True
    if '-norm' in sys.argv:
        normIndex = sys.argv.index('-norm') + 1
        if normIndex >= len(sys.argv):
            # '-norm' given without a method name
            _usage_and_exit()
        method = sys.argv[normIndex]
        if method == 'median-of-ratios':
            doLDS = False
        elif method != 'lds':
            sys.stderr.write('warning: unknown normalization %r, falling back to lds\n' % method)
    if doLDS:
        print('will apply lds normalization')
    else:
        print('will apply the median-of-ratios normalization')

    headerLines, TotalReadCountDict, readCountsDict = _read_counts(
        readCounts, STARRFieldIDs, ControlFieldIDs)

    print('%d region found' % len(readCountsDict))

    if doLDS:
        factors = _lds_factors(TotalReadCountDict)
    else:
        factors = _median_of_ratios_factors(readCountsDict, STARRFieldIDs, ControlFieldIDs)
    _apply_factors(readCountsDict, STARRFieldIDs, ControlFieldIDs, factors)

    print('finished read count normalization')

    # The output is only opened once everything has been computed, so a
    # failure above does not leave a truncated file behind.
    with open(outfilename, 'w') as outfile:
        for outline in headerLines:
            outfile.write(outline + '\n')
        for region in sorted(readCountsDict):
            entry = readCountsDict[region]
            values = [str(value) for value in entry['SN'] + entry['CN']]
            outfile.write('\t'.join([region] + values) + '\n')

# Only run the command-line tool when this file is executed as a script,
# so that importing the module does not trigger argument parsing and I/O.
if __name__ == '__main__':
    run()

