##################################
#                                #
# Last modified 04/12/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from scipy.stats import hypergeom
from scipy import special
import numpy as np

def logchoose(n, k):
    """Return the natural logarithm of the binomial coefficient C(n, k).

    Computed as lnΓ(n+1) - lnΓ(k+1) - lnΓ(n-k+1) so that large arguments do
    not overflow.  Because ``special.gammaln`` is applied elementwise, *n*
    and *k* may be scalars or NumPy arrays.
    """
    log_numerator = special.gammaln(n + 1)
    log_denominator = special.gammaln(k + 1) + special.gammaln(n - k + 1)
    return log_numerator - log_denominator

def gauss_hypergeom(x, r, b, n):
    """Return log P(X = x) for a hypergeometric variable X.

    X is the number of the *n* draws (without replacement) that come from a
    group of *r* items, when the population holds *r* + *b* items:

        P(X = x) = C(r, x) * C(b, n - x) / C(r + b, n)

    Note that *r* and *b* are the success and failure counts here, not the
    population size.  Works elementwise when *x* is a NumPy array.
    """
    log_from_r = logchoose(r, x)
    log_from_b = logchoose(b, n - x)
    log_all_draws = logchoose(r + b, n)
    return log_from_r + log_from_b - log_all_draws

def hypergeom_cdf(x, r, b, n):
    """Return log P(X < x) for the hypergeometric variable of gauss_hypergeom.

    X counts how many of *n* draws come from the group of *r* items in a
    population of *r* + *b*.  The sum runs over X = 0 .. x-1, i.e. strictly
    below *x*, so this is the complement of hypergeom_pvalue(x, r, b, n):
    their exponentials add up to 1 (up to rounding).

    Returns the natural log of the probability, or -inf when x <= 0
    (empty sum, probability 0).
    """
    if x <= 0:
        # Empty range.  Handled explicitly because np.logaddexp.reduce on an
        # empty array raises on NumPy releases where logaddexp has no identity.
        return np.float64(-np.inf)
    # gauss_hypergeom is elementwise, so evaluate all terms in one call.
    log_terms = gauss_hypergeom(np.arange(x), r, b, n)
    return np.logaddexp.reduce(log_terms)

def hypergeom_pvalue(x, r, b, n):
    """Return the natural log of the upper-tail p-value P(X >= x).

    X is the hypergeometric variable of gauss_hypergeom: the number of the
    *n* draws that come from the group of *r* items in a population of
    *r* + *b*.  Terms X = x .. n are combined with log-sum-exp so tiny
    probabilities do not underflow.

    Returns -inf when x > n (no attainable value, probability 0).
    """
    if x > n:
        # Empty range.  Handled explicitly because np.logaddexp.reduce on an
        # empty array raises on NumPy releases where logaddexp has no identity.
        return np.float64(-np.inf)
    # gauss_hypergeom is elementwise, so evaluate all terms in one call.
    log_terms = gauss_hypergeom(np.arange(x, n + 1), r, b, n)
    return np.logaddexp.reduce(log_terms)

def hypergeom_stats(r, b, n):
    """Return (mean, std) of a hypergeometric distribution.

    Unlike gauss_hypergeom, here *r* is the total population size, *b* the
    number of success items in it and *n* the number of draws:

        mean = n*b/r
        var  = n*b*(r-b)/r**2 * (r-n)/(r-1)

    Raises ZeroDivisionError when r is 0 or 1, and ValueError (from
    math.sqrt) when the variance comes out negative, e.g. for n > r.
    """
    population = r + 0.0
    mean = b * n / population
    # Binomial-like spread times the finite-population correction.
    spread = (b * n * (r - b)) / (r * r * 1.0)
    finite_correction = (r - n) / (r - 1.0)
    std = math.sqrt(spread * finite_correction)
    return mean, std

def _read_genome_length(chrom_sizes_path):
    """Return the sum of column 2 of a tab-separated chrom.sizes file."""
    total = 0
    with open(chrom_sizes_path) as handle:
        for line in handle:
            if not line.strip():
                continue
            fields = line.rstrip('\n').split('\t')
            total += int(fields[1])
    return total


def _parse_call(fields, call_format):
    """Return (chromosome, peak_position) from one split region-call line."""
    if call_format == 'ERANGE':
        return fields[1], int(fields[9])
    if call_format == 'MACS':
        # NOTE(review): chromosome is read from column 1 and the start from
        # column 0; standard MACS .xls output has chr in column 0 and start
        # in column 1 -- confirm against the files this is run on.
        return fields[1], int(fields[0]) + int(fields[4])
    raise ValueError('unsupported region call format: %s' % call_format)


def _read_region_calls(list_path, call_format):
    """Load peak positions for every ``label<TAB>filename`` line of *list_path*.

    Returns (peaks_by_chrom, region_peaks, total_calls):
      peaks_by_chrom -- chromosome -> set of peak positions over all files
      region_peaks   -- label -> chromosome -> set of peak positions
      total_calls    -- label -> number of non-comment lines read
    A repeated label replaces the earlier label's calls, as before.
    """
    peaks_by_chrom = {}
    region_peaks = {}
    total_calls = {}
    with open(list_path) as listing:
        for label_line in listing:
            if not label_line.strip():
                continue
            label_fields = label_line.rstrip('\n').split('\t')
            label = label_fields[0]
            calls_path = label_fields[1]
            label_peaks = {}
            count = 0
            with open(calls_path) as calls:
                for line in calls:
                    if line.startswith('#') or not line.strip():
                        continue
                    fields = line.rstrip('\n').split('\t')
                    chrom, peak = _parse_call(fields, call_format)
                    peaks_by_chrom.setdefault(chrom, set()).add(peak)
                    label_peaks.setdefault(chrom, set()).add(peak)
                    count += 1
            region_peaks[label] = label_peaks
            total_calls[label] = count
    return peaks_by_chrom, region_peaks, total_calls


def _read_repeats(path, chrom_field, class_field):
    """Group repeat intervals by class: class -> list of (chrom, left, right).

    Columns chrom_field, chrom_field+1 and chrom_field+2 hold the chromosome,
    left and right coordinates; class_field holds the repeat class.
    """
    repeats = {}
    with open(path) as handle:
        for line_number, line in enumerate(handle, 1):
            if line_number % 1000000 == 0:
                print('%d lines processed' % line_number)
            if not line.strip():
                continue
            fields = line.rstrip('\n').split('\t')
            chrom = fields[chrom_field]
            left = int(fields[chrom_field + 1])
            right = int(fields[chrom_field + 2])
            repeats.setdefault(fields[class_field], []).append((chrom, left, right))
    return repeats


def _zscore(x, N, M, n):
    """Return (x - mean) / std under hypergeom_stats(N, M, n), NaN if undefined."""
    try:
        mean, std = hypergeom_stats(N, M, n)
        return (x - mean) / std
    except (ZeroDivisionError, ValueError):
        # Degenerate distribution (std == 0, N <= 1, or n > N).  Report the
        # inputs and emit NaN instead of a missing or stale score.
        print(' '.join(str(v) for v in (x, N, M, n)))
        return float('nan')


def run():
    """Command-line entry point: repeat-class enrichment of region calls.

    Usage: repeatMasker chrFieldID RepeatClassFieldID listOfRegionCalls
           [ERANGE | MACS] chr.sizes outfile [-ZScore]

    Writes a tab-separated table to *outfile* and also prints each row.
    There is one row per repeat class and one column per label, both sorted.
    For each repeat class:
      N = int(genome length / average repeat length)  -- repeat-sized slots
      M = number of repeat instances of the class
    For each label:
      n = number of calls for the label
      x = calls that fall inside a repeat of the class, where left <= p < right
    Each cell is the natural-log upper-tail hypergeometric p-value
    log P(X >= x), or with -ZScore the z-score (x - mean) / std.
    """
    import bisect

    if len(sys.argv) < 8:
        print('usage: python %s repeatMasker chrFieldID RepeatClassFieldID listOfRegionCalls [ERANGE | MACS] chr.sizes outfile [-ZScore]' % sys.argv[0])
        print('       listOfRegionCalls format: lable <tab> filename')
        sys.exit(1)

    repeat_masker = sys.argv[1]
    chrom_field = int(sys.argv[2])
    class_field = int(sys.argv[3])
    list_of_files = sys.argv[4]
    call_format = sys.argv[5]
    chrom_sizes = sys.argv[6]
    out_path = sys.argv[7]

    if call_format not in ('ERANGE', 'MACS'):
        sys.exit('unsupported region call format %r; expected ERANGE or MACS' % call_format)

    with open(out_path, 'w') as outfile:
        do_z = '-ZScore' in sys.argv
        if do_z:
            print('will output z-scores')

        genome_length = _read_genome_length(chrom_sizes)
        peaks_by_chrom, region_peaks, total_calls = _read_region_calls(list_of_files, call_format)
        print('finished inputting regions')

        repeat_dict = _read_repeats(repeat_masker, chrom_field, class_field)
        print('finished inputting repeat annotation')

        # Sorted peak positions per chromosome.  Each repeat interval is
        # answered with two binary searches instead of scanning every base.
        sorted_peaks = dict((chrom, sorted(positions))
                            for chrom, positions in peaks_by_chrom.items())

        labels = sorted(region_peaks)
        outfile.write('\t'.join(['#'] + labels) + '\n')

        for repeat in sorted(repeat_dict):
            intervals = repeat_dict[repeat]
            total_length = 0
            # NOTE(review): a peak inside several overlapping intervals of the
            # same class is listed once per interval, as before.
            overlapping = []
            for chrom, left, right in intervals:
                total_length += right - left
                positions = sorted_peaks.get(chrom)
                if not positions:
                    continue  # no peaks at all on this chromosome
                lo = bisect.bisect_left(positions, left)
                hi = bisect.bisect_left(positions, right)
                overlapping.extend((chrom, p) for p in positions[lo:hi])

            average_length = total_length / (len(intervals) + 0.0)
            N = int(genome_length / average_length)
            M = len(intervals)

            values = []
            for label in labels:
                n = total_calls[label]
                label_peaks = region_peaks[label]
                x = sum(1 for chrom, p in overlapping
                        if p in label_peaks.get(chrom, ()))
                if do_z:
                    values.append(str(_zscore(x, N, M, n)))
                else:
                    # gauss_hypergeom's r/b are the success/failure counts:
                    # successes are the M repeat instances and failures are
                    # the other N - M slots (population N).
                    values.append(str(hypergeom_pvalue(x, M, N - M, n)))

            outline = '\t'.join([repeat] + values)
            print(outline)
            outfile.write(outline + '\n')

# Run only when executed as a script, so importing this module (e.g. to reuse
# the hypergeometric helpers) does not parse sys.argv or exit.
if __name__ == '__main__':
    run()