##################################
#                                #
# Last modified 02/08/2009       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
import numpy
from scipy.stats import norm
from cistematic.core import Genome
from cistematic.core.geneinfo import geneinfoDB
from commoncode import *

def run():
    """Find windows whose read counts differ between a test and a reference dataset.

    Command line (seven positional arguments, then optional flags):
        genome testRDSfilename refRDSfilename p-value ratio slidestep(bp)
        outputfilename [-chrom number] [-minRPKM value] [-MHcorrection]
        [-cache size]

    A window size W is derived from the p-value, the fold ``ratio``, the
    genome size (sum of the lengths of the selected chromosomes, skipping
    names containing 'rand') and the total read counts of both datasets.
    Windows of W bp are slid along each chromosome by ``slidestep`` bp.

    For every window:
    - fetch the unique-read counts X (test) and Y (reference);
    - form the ratio z = (X+1)/(Y+1) and a t-like statistic from it;
    - standardize the statistic across all windows and convert it to a
      two-sided normal p-value.

    Windows with p-value <= threshold are written as tab-separated lines to
    ``outputfilename``.

    Options:
    - ``-minRPKM``: skip windows where neither dataset reaches the given
      RPKM value.
    - ``-MHcorrection``: divide the threshold once by GenomeSize / W.
    - ``-cache``: read both datasets through the dataset cache and raise
      the sqlite page cache to the given size.

    Exits with status 1 on a usage error or invalid parameters.
    """
    # seven positional arguments are required: sys.argv[1] .. sys.argv[7]
    if len(sys.argv) < 8:
        print('usage: python %s genome testRDSfilename refRDSfilename p-value ratio slidestep(bp) outputfilename [-chrom number] [-minRPKM value] [-MHcorrection] [-cache size] ' % sys.argv[0])
        sys.exit(1)

    genome = sys.argv[1]
    testRDSfilename = sys.argv[2]
    refRDSfilename = sys.argv[3]
    pvalue = float(sys.argv[4])
    ratio = float(sys.argv[5])
    slidestep = int(sys.argv[6])
    outfilename = sys.argv[7]

    # a non-positive step would never advance the sliding-window loop
    if slidestep < 1:
        print('slidestep must be a positive integer')
        sys.exit(1)

    # the window-size formula below divides by (1 - ratio)
    if ratio == 1:
        print('ratio must differ from 1')
        sys.exit(1)

    # opened early so an unwritable path fails before the long scan starts
    outfile = open(outfilename, 'w')
    try:
        hg = Genome(genome)
        chromlist = hg.allChromNames()
        if '-chrom' in sys.argv:
            chromlist = [sys.argv[sys.argv.index('-chrom') + 1]]

        # chromosome name -> sequence length; names containing 'rand' are skipped
        chrDict = {}
        GenomeSize = 0
        for chrom in chromlist:
            if 'rand' in chrom:
                continue
            chrDict[chrom] = len(hg.getChromosomeSequence(chrom))
            GenomeSize += chrDict[chrom]

        print(chrDict)

        cachePages = -1
        doCache = False
        if '-cache' in sys.argv:
            doCache = True
            cachePages = int(sys.argv[sys.argv.index('-cache') + 1])

        # honor the -cache flag instead of always caching
        testRDS = readDataset(testRDSfilename, verbose=True, cache=doCache)
        refRDS = readDataset(refRDSfilename, verbose=True, cache=doCache)

        # sqlite default_cache_size is 2000 pages; only ever enlarge it,
        # and do so for both datasets
        if cachePages > testRDS.getDefaultCacheSize():
            testRDS.setDBcache(cachePages)
        if cachePages > refRDS.getDefaultCacheSize():
            refRDS.setDBcache(cachePages)

        testReadNumber = len(testRDS)
        refReadNumber = len(refRDS)

        # both totals appear in the denominator of the window-size formula
        if not testReadNumber or not refReadNumber:
            print('both datasets must contain reads')
            sys.exit(1)

        T = norm.ppf(0.5 * pvalue)
        W = int(((testReadNumber * ratio * ratio + refReadNumber) * GenomeSize * T * T)
                / ((1 - ratio) * (1 - ratio) * testReadNumber * refReadNumber))
        print('will use window size of %s slide step of %s' % (W, slidestep))

        # a zero-length window is meaningless and would break RPKM normalization
        if W < 1:
            print('computed window size is smaller than 1 bp')
            sys.exit(1)

        dominRPKM = '-minRPKM' in sys.argv
        if dominRPKM:
            minRPKM = float(sys.argv[sys.argv.index('-minRPKM') + 1])
            # divisor turning a window count into reads per kb per million reads
            normalizeRefBy = (refReadNumber / 1000000.) * (W / 1000.)
            normalizeTestBy = (testReadNumber / 1000000.) * (W / 1000.)

        doMHCorrection = '-MHcorrection' in sys.argv

        # (chrom, start, stop) -> [test count, reference count] for every window
        countsDict = {}
        testcounts = []
        refcounts = []
        for chrom in chrDict:
            print('chr' + str(chrom))
            chromLength = chrDict[chrom]
            i = 0
            while i + W < chromLength:
                testvalue = testRDS.getCounts(chrom='chr' + str(chrom), rmin=i, rmax=i + W, uniqs=True, multi=False, splices=False, reportCombined=True)
                refvalue = refRDS.getCounts(chrom='chr' + str(chrom), rmin=i, rmax=i + W, uniqs=True, multi=False, splices=False, reportCombined=True)
                countsDict[(chrom, i, i + W)] = [testvalue, refvalue]
                testcounts.append(testvalue)
                refcounts.append(refvalue)
                if i % 10000000 == 0:
                    print(i)
                i += slidestep

        testcounts = numpy.array(testcounts)
        refcounts = numpy.array(refcounts)
        testmean = numpy.mean(testcounts)
        refmean = numpy.mean(refcounts)
        testStdDev = numpy.std(testcounts)
        refStdDev = numpy.std(refcounts)
        print('%s %s %s %s' % (testmean, refmean, testStdDev, refStdDev))

        tvalues = []
        # iterate over a snapshot of the keys because values are replaced below
        for key in list(countsDict):
            X, Y = countsDict[key]
            # float pseudocounts: integer counts must not truncate the ratio
            z = (X + 1.0) / (Y + 1.0)
            r = (z * refReadNumber) / testReadNumber
            t = (testmean * z - refmean) / math.sqrt(refStdDev * refStdDev * z * z + testStdDev * testStdDev)
            countsDict[key] = (X, Y, t, r)
            tvalues.append(t)

        tvalues = numpy.array(tvalues)
        tvaluesmean = numpy.mean(tvalues)
        tvaluesstd = numpy.std(tvalues)

        # p-value divided by the number of non-overlapping windows (GenomeSize / W)
        correctionFactor = pvalue / (GenomeSize / (W + 0.0))
        # the correction is applied once, not compounded per window
        if doMHCorrection:
            threshold = correctionFactor
        else:
            threshold = pvalue

        outfile.write('#stats:\n')
        outfile.write('#Genome Size=' + str(GenomeSize) + '\n')
        outfile.write('#Window Size=' + str(W) + '\n')
        outfile.write('#Sliding Window Step=' + str(slidestep) + '\n')
        outfile.write('#Total regions examined=' + str(len(countsDict)) + '\n')
        outfile.write('#minRPKM=' + str(dominRPKM) + '\n')
        outfile.write('#Multiple hypothesis testing correction=' + str(doMHCorrection) + '\n')
        outfile.write('#Multiple hypothesis testing correction factor=' + str(correctionFactor) + '\n')
        outfile.write('#chr\tStart\tStop\tTest Reads\tReference Reads\tratio\tt-value\tp-value\n')

        # sorted so that output order is deterministic (by chromosome, then start)
        for key in sorted(countsDict):
            chrom, start, stop = key
            X, Y, t, r = countsDict[key]
            if dominRPKM:
                XRPKM = X / normalizeTestBy
                YRPKM = Y / normalizeRefBy
                # skip windows where neither dataset reaches the RPKM floor
                if XRPKM < minRPKM and YRPKM < minRPKM:
                    continue
            ttrans = (t - tvaluesmean) / tvaluesstd
            # two-sided p-value on the tail matching the direction of change
            if r >= 1:
                p = 2 * (1 - norm.cdf(ttrans))
            else:
                p = 2 * norm.cdf(ttrans)
            if p <= threshold:
                outline = '%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t' % (chrom, start, stop, X, Y, r, t, p)
                outfile.write(outline + '\n')
    finally:
        outfile.close()
   
# Run the scan only when executed as a script, not when imported.
if __name__ == '__main__':
    run()
