##################################
#                                #
# Last modified 2017/06/01       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import gc
import math
import string
from sets import Set
import numpy as np
from scipy.stats import hypergeom

def _read_overlaps(path, gene_field, repeat_field):
    """Load the gene/repeat overlap table that defines the background universe.

    Lines starting with '#' and blank lines are skipped. Every gene ID seen is
    registered, even when its repeat column is 'nan' (meaning no overlap).

    Args:
        path: tab-delimited overlap file.
        gene_field: 0-based column holding the gene ID.
        repeat_field: 0-based column holding the repeat name.

    Returns:
        (gene_repeats, repeat_gene_counts): gene_repeats maps gene ID -> set
        of repeat names it overlaps; repeat_gene_counts maps repeat name ->
        number of DISTINCT genes overlapping it.
    """
    gene_repeats = {}
    repeat_gene_counts = {}
    with open(path) as handle:
        for line in handle:
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            repeats = gene_repeats.setdefault(fields[gene_field], set())
            repeat = fields[repeat_field]
            # Count each (gene, repeat) pair once even if the file lists it on
            # several lines: this count is a number of genes and must be
            # comparable to the total gene count in the hypergeometric test.
            if repeat == 'nan' or repeat in repeats:
                continue
            repeats.add(repeat)
            repeat_gene_counts[repeat] = repeat_gene_counts.get(repeat, 0) + 1
    return gene_repeats, repeat_gene_counts


def _count_label_repeats(path, gene_field, gene_repeats):
    """Count how many genes of one gene-set file overlap each repeat.

    Args:
        path: tab-delimited gene-set file ('#' and blank lines skipped).
        gene_field: 0-based column holding the gene ID.
        gene_repeats: gene ID -> set of repeats, as from _read_overlaps.

    Returns:
        (total, counts): number of gene lines in the file and a dict mapping
        repeat name -> number of listed genes overlapping it. Genes missing
        from gene_repeats still count toward total.
    """
    total = 0
    counts = {}
    with open(path) as handle:
        for line in handle:
            if line.startswith('#') or not line.strip():
                continue
            total += 1
            gene_id = line.strip().split('\t')[gene_field]
            for repeat in gene_repeats.get(gene_id, ()):
                counts[repeat] = counts.get(repeat, 0) + 1
    return total, counts


def _enrichment_pvalue(total, drawn, successes, observed, log10=False):
    """Return the one-sided hypergeometric enrichment p-value P(X >= observed).

    X ~ scipy.stats.hypergeom(total, drawn, successes). The p-value is floored
    at 1e-300 so the optional -log10 transform stays finite.
    """
    # sf(k - 1) == P(X >= k). The alternative 1 - cdf(k) would be P(X > k),
    # which excludes the observed count and underflows for small p-values.
    p = max(float(hypergeom(total, drawn, successes).sf(observed - 1)), 1e-300)
    if log10:
        # 0.0 - x instead of -x so that p == 1 yields '0.0', not '-0.0'.
        p = 0.0 - math.log10(p)
    return p


def run(argv=None):
    """Test each gene set listed in a config file for repeat-overlap enrichment.

    Command line (argv defaults to sys.argv):
        config individual_overlaps geneIDfieldID repeatFieldID outfilename [-log10]

    individual_overlaps is a tab-delimited table (output from
    ExonRepeatOverlap.py) with a gene ID column and a repeat column ('nan'
    means no overlap); all gene IDs in it form the background universe.
    Each non-blank, non-'#' config line is: label <tab> filename <tab> fieldID.

    For every repeat overlapped by at least one gene of a set, one row is
    written to outfilename with the columns
        label, repeat, TotalGenes, GenesInLabel, OverlapAllGenes,
        OverlapInLabel, p-val
    where p-val is the hypergeometric P(X >= OverlapInLabel), or its -log10
    when '-log10' is given. Rows within one label are sorted by repeat name.

    Prints usage and exits with status 1 when fewer than five arguments are
    supplied.
    """
    if argv is None:
        argv = sys.argv

    # Five positional arguments are required: argv[1] .. argv[5].
    if len(argv) < 6:
        sys.stdout.write('usage: python %s config individual_overlaps geneIDfieldID '
                         'repeatFieldID outfilename [-log10]\n' % argv[0])
        sys.stdout.write('\tindividual_overlaps format: output from ExonRepeatOverlap.py\n')
        sys.stdout.write('\tconfig file: label <tab> filename <tab> fieldID\n')
        sys.exit(1)

    config_path = argv[1]
    overlaps_path = argv[2]
    gene_field = int(argv[3])
    repeat_field = int(argv[4])
    out_path = argv[5]
    do_log10 = '-log10' in argv

    gene_repeats, repeat_gene_counts = _read_overlaps(overlaps_path, gene_field, repeat_field)
    total_genes = len(gene_repeats)
    sys.stdout.write('TotalGenes: %d\n' % total_genes)

    with open(out_path, 'w') as outfile, open(config_path) as config_file:
        outfile.write('#label\trepeat\tTotalGenes\tGenesInLabel\tOverlapAllGenes'
                      '\tOverlapInLabel\tp-val\n')
        for config_line in config_file:
            if config_line.startswith('#') or not config_line.strip():
                continue
            fields = config_line.strip().split('\t')
            label = fields[0]
            label_path = fields[1]
            label_field = int(fields[2])
            sys.stdout.write(label + '\n')
            # NOTE(review): genes in a set but absent from individual_overlaps
            # still count toward GenesInLabel; if it exceeds TotalGenes the
            # hypergeometric p-value is nan. Confirm sets are subsets of the
            # universe.
            genes_in_label, label_counts = _count_label_repeats(
                label_path, label_field, gene_repeats)
            for repeat in sorted(label_counts):
                # Urn model: total_genes genes, of which overlap_all overlap
                # this repeat; genes_in_label genes are sampled (the set) and
                # overlap_in_label of them overlap the repeat.
                overlap_all = repeat_gene_counts[repeat]
                overlap_in_label = label_counts[repeat]
                p = _enrichment_pvalue(total_genes, genes_in_label, overlap_all,
                                       overlap_in_label, do_log10)
                row = [label, repeat, str(total_genes), str(genes_in_label),
                       str(overlap_all), str(overlap_in_label), str(p)]
                outfile.write('\t'.join(row) + '\n')

# Run only when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    run()

