##################################
#                                #
# Last modified 2018/05/19       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
import random
import numpy as np
from sets import Set
from operator import add
from scipy import stats
from scipy.cluster import hierarchy

def run():

    if len(sys.argv) < 3:
        print 'usage: python %s  datafilename IDs outfilename [-verbose] [-excludeHOT fraction] [-subsample SampleSize N_iterations]' % sys.argv[0]
        print '\tformat of IDs: either comma separated or start-end (including end)'
        print '\toutput from combinePeaks.py is assumed, with a header line'
        print '\tthe script will print out the -log10p values by default'
        print '\tthe script will take stdin as input'
        sys.exit(1)

    datafilename = sys.argv[1]
    IDfields = []
    if '-' in sys.argv[2]:
        fields1 = int(sys.argv[2].split('-')[0])
        fields2 = int(sys.argv[2].split('-')[1])
        for ID in range(fields1,fields2+1):
            IDfields.append(ID)
    else:
        fields = sys.argv[2].split(',')
        for ID in fields:
            IDfields.append(int(ID))
    LabelDict={}
    DataDict={}
    outfilename = sys.argv[3]

    doVerbose = False
    if '-verbose' in sys.argv:
        doVerbose = True

    doExcludeHOT = False
    if '-excludeHOT' in sys.argv:
        doExcludeHOT = True
        HOTcutoff = float(sys.argv[sys.argv.index('-excludeHOT') + 1])
        print 'will exclude regions found in more than ' + str(100*HOTcutoff) + '% of datasets'

    doSubSample = False
    if '-subsample' in sys.argv:
        doSubSample = True
        SampleSize = int(sys.argv[sys.argv.index('-subsample') + 1])
        N_samples = int(sys.argv[sys.argv.index('-subsample') + 2])
        print 'will subsample peak sets down to', SampleSize, 'peaks, will take the average from', N_samples, 'samplings'

    outfile = open(outfilename, 'w')

    if datafilename == '-':
        lineslist  = sys.stdin
    else:
        lineslist  = open(datafilename)
    i=0
    for line in lineslist:
        i+=1
        if i % 100000 == 0:
            print i
        if line[0]=='#':
            fields=line.split('\n')[0].split('\t')
            for ID in IDfields:
                LabelDict[ID] = fields[ID]
                print ID, fields[ID]
                DataDict[fields[ID]] = []
            continue
        fields = line.strip().split('\t')
        if doExcludeHOT:
            Total = 0.0
            HOT = 0.0
            for ID in IDfields:
                if float(fields[ID]) == 0:
                    pass
                else:
                    HOT += 1
                Total += 1
            if HOT/Total >= HOTcutoff:
                continue
        for ID in IDfields:
            if len(fields) <= ID:
                continue
            if float(fields[ID]) == 0:
                DataDict[LabelDict[ID]].append(0)
            else:
                DataDict[LabelDict[ID]].append(1)

    print 'finished inputting table'

    TFs = DataDict.keys()
    TFs.sort()

    outline = '#TF'
    for TF1 in TFs:
        outline = outline + '\t' + TF1
    outfile.write(outline + '\n')

    for TF1 in TFs:
        outline = TF1
        print TF1
        for TF2 in TFs:
            if doSubSample:
                P1 = DataDict[TF1].count(1)
                P2 = DataDict[TF2].count(1)
                pvals = []
                if P1 == 0 or P2 == 0:
                    outline = outline + '\t' + 'nan'
                    continue
                else:
                    pass
                for i in range(N_samples):
                    Z = zip(DataDict[TF1],DataDict[TF2])
                    Z.sort()
                    Z.reverse()
                    R1 = random.sample(Z[0:P1],min(SampleSize,P1))
                    Z1, Z2  = zip(*R1)
                    common = map(add, Z1, Z2)
                    C1 = common.count(2)
                    Z = zip(DataDict[TF2],DataDict[TF1])
                    Z.sort()
                    Z.reverse()
                    R2 = random.sample(Z[0:P2],min(SampleSize,P2))
                    Z1, Z2  = zip(*R2)
                    common = map(add, Z1, Z2)
                    C2 = common.count(2)
                    oddsratio, pvalue = stats.fisher_exact([[min(SampleSize,P1) - C1, C1], [min(SampleSize,P2) - C2, C2]])
                    pvals.append(pvalue)
                if np.mean(pvals) > 0:
                    outline = outline + '\t' + str(-math.log10(np.mean(pvals)))
                else:
                    outline = outline + '\t' + str(324)
            else:
                common = map(add, DataDict[TF1], DataDict[TF2])
                C = common.count(2)
                U1 = DataDict[TF1].count(1) - C
                U2 = DataDict[TF2].count(1) - C
                oddsratio, pvalue = stats.fisher_exact([[U1, C], [U2, C]])
                if doVerbose:
                    print TF1, TF2, oddsratio, pvalue, U1, C, U2, C
                if pvalue > 0:
                    outline = outline + '\t' + str(-math.log10(pvalue))
                else:
                    outline = outline + '\t' + str(324)

        outfile.write(outline+'\n')
        
    outfile.close()

# Guard the entry point so importing this module does not immediately
# execute the command-line driven script.
if __name__ == '__main__':
    run()

