##################################
#                                #
# Last modified 5/7/2009         # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
import random
from sets import Set
from cistematic.core import Genome
from cistematic.core.geneinfo import geneinfoDB
import sqlite3 as sqlite
import sys


# psyco is an optional accelerator; the script runs without it.
# Catch Exception rather than using a bare except so that KeyboardInterrupt
# and SystemExit still propagate, while any failure to load or enable psyco
# remains non-fatal.
try:
    import psyco
    psyco.full()
except Exception:
    print('psyco not running')

from commoncode import *

def countIntersects(regionDict1,regionDict2,chrlist,outfile):
    """Count the regions of regionDict1 that overlap a region of regionDict2.

    Both dictionaries map chromosome -> region id -> dict with integer
    'start' and 'stop' entries.  For every chromosome in chrlist one summary
    line holding the mean region length of each set is written to outfile,
    then each region of regionDict1 is tested against the regions of
    regionDict2 on the same chromosome.

    Parameters:
        regionDict1, regionDict2: region dictionaries as described above.
        chrlist: chromosome names to process; a name missing from either
            dictionary is treated as having no regions.
        outfile: object with a write(str) method receiving the summary lines.

    Returns:
        Number of regions in regionDict1 that overlap at least one region in
        regionDict2 (each region is counted at most once).  Intervals are
        treated as closed, so regions sharing only an end point overlap.
    """
    intersect = 0
    for chrom in chrlist:
        regions1 = list(regionDict1.get(chrom, {}).values())
        regions2 = list(regionDict2.get(chrom, {}).values())

        lengths1 = [region['stop'] - region['start'] for region in regions1]
        lengths2 = [region['stop'] - region['start'] for region in regions2]

        # Report the mean (the label says "Average"); an empty set gives 0.0
        # instead of dividing by zero.
        average1 = 0.0
        if lengths1:
            average1 = float(sum(lengths1)) / len(lengths1)
        average2 = 0.0
        if lengths2:
            average2 = float(sum(lengths2)) / len(lengths2)

        line = chrom + '\t Average Dict1 regions length: ' + str(average1) + '\t Average Dict2 regions length: ' + str(average2) + '\n'
        outfile.write(line)

        for first in regions1:
            for second in regions2:
                # Either the second region starts inside the first one, or it
                # starts at/before the first one's start and reaches it.
                startsInside = first['start'] <= second['start'] <= first['stop']
                coversStart = second['start'] <= first['start'] <= second['stop']
                if startsInside or coversStart:
                    intersect += 1
                    break

    return intersect

def createRandomRegionDict(regionDict,chrDict):
    """Place every region of regionDict at a random position in the genome.

    The chromosomes of chrDict are laid end to end (in chrDict iteration
    order) and each region keeps its length but gets a new random start in
    that concatenated coordinate space.  A region that would run past the end
    of the chromosome containing its start is re-drawn inside that
    chromosome, so every input region is placed exactly once.

    Parameters:
        regionDict: chromosome -> region id -> dict with 'start' and 'stop'.
            Chromosomes of chrDict missing here are treated as empty.
        chrDict: chromosome -> chromosome length.

    Returns:
        chromosome -> region id -> {'start': int, 'stop': int}, with one entry
        per chromosome of chrDict.  Coordinates are offsets into the
        concatenated genome, not chromosome-local positions.

    Raises:
        ValueError: from random.randint when a region is longer than the
            chromosome it lands on (or longer than the whole genome).
    """
    randomRegionDict = {}
    chrNames = []
    chrEndPositionsList = []
    genomeEnd = 0
    for chrom in chrDict.keys():
        genomeEnd += chrDict[chrom]
        chrNames.append(chrom)
        chrEndPositionsList.append(genomeEnd)
        randomRegionDict[chrom] = {}
    print(chrEndPositionsList)
    if not chrNames:
        return randomRegionDict

    lastIndex = len(chrNames) - 1
    for chrom in chrNames:
        regions = regionDict.get(chrom, {})
        for i in regions.keys():
            size = regions[i]['stop'] - regions[i]['start']
            start = random.randint(0, genomeEnd - size)

            # Find the chromosome whose span [lower, upper) contains start.
            index = 0
            while index < lastIndex and start >= chrEndPositionsList[index]:
                index += 1
            upper = chrEndPositionsList[index]
            lower = 0
            if index > 0:
                lower = chrEndPositionsList[index - 1]

            # The region crosses the chromosome end: re-draw inside it.
            if start + size > upper:
                start = random.randint(lower, upper - size)

            randomRegionDict[chrNames[index]][i] = {'start': start, 'stop': start + size}

    return randomRegionDict

# def createRandomRegionDictRepeatMasked(regionDict,chrDict,rmskdb):
#
#    db = sqlite.connect(rmskdb)                        
#    sql = db.cursor()
#                        
#    return randomRegionDict

def createRandomRegionDictRepeatMasked(regionDict,chrDict,repeatDict):
    """Randomly re-place regions, rejecting draws flagged by the repeat test.

    The chromosomes of chrDict are laid end to end (in chrDict iteration
    order).  Each region keeps its length and is drawn again until the draw
    passes the repeat test and lies strictly inside a single chromosome.

    Parameters:
        regionDict: chromosome -> region id -> dict with 'start' and 'stop'.
            Chromosomes of chrDict missing here are treated as empty.
        chrDict: chromosome -> chromosome length.
        repeatDict: chromosome -> consecutive integer index (from 0) -> dict
            with 'start' and 'stop'.  Must contain every chromosome that has
            regions.

    Returns:
        chromosome -> region id -> {'start': int, 'stop': int}, with one entry
        per chromosome of chrDict.  Coordinates are offsets into the
        concatenated genome, not chromosome-local positions.
    """
    randomRegionDict = {}
    chrNames = []
    chrEndPositionsList = []
    genomeEnd = 0
    for chrom in chrDict.keys():
        genomeEnd += chrDict[chrom]
        chrNames.append(chrom)
        chrEndPositionsList.append(genomeEnd)
        randomRegionDict[chrom] = {}
    print(chrEndPositionsList)
    if not chrNames:
        return randomRegionDict

    for chrom in chrNames:
        regions = regionDict.get(chrom, {})
        # Build the id list once instead of on every attempt.
        regionIds = list(regions.keys())
        c = 0
        while c < len(regionIds):
            i = regionIds[c]
            size = regions[i]['stop'] - regions[i]['start']
            start = random.randint(0, genomeEnd - size)
            stop = start + size

            # NOTE(review): the test below is kept exactly as it was written.
            # It uses the repeats of the region's source chromosome and
            # compares them with concatenated-genome coordinates, it looks at
            # pairs of consecutive repeats, and it skips the last pair.
            # Confirm the intended rejection rule before relying on it.
            repeats = repeatDict[chrom]
            inrepeat = False
            for k in range(0, len(repeats) - 2):
                if (repeats[k + 1]['stop'] > start and start > repeats[k]['start']) or (repeats[k + 1]['start'] > start and stop > repeats[k]['stop']):
                    inrepeat = True
                    break
            if inrepeat:
                continue

            # Accept only draws that fall strictly inside one chromosome.
            target = None
            if stop < chrEndPositionsList[0]:
                target = chrNames[0]
            else:
                for j in range(1, len(chrEndPositionsList)):
                    if chrEndPositionsList[j] > stop and start > chrEndPositionsList[j - 1]:
                        target = chrNames[j]
                        break
            if target is None:
                # Crosses or touches a chromosome boundary: draw again.
                continue

            randomRegionDict[target][i] = {'start': start, 'stop': stop}
            c += 1
            print('%s %s' % (chrom, c))

    return randomRegionDict



def _loadRepeatDict(rmaskdir):
    """Load the repeat annotation files found in rmaskdir.

    Every file whose name contains 'rmsk' is read.  The chromosome name is
    the part of the file name before the first '_'; start and stop are taken
    from tab-separated columns 6 and 7 (0-based).
    NOTE(review): presumably the UCSC rmsk table layout - confirm.

    Returns chromosome -> running line index -> {'start': int, 'stop': int}.
    """
    # Local import: os is not among the imports visible at the top of the file.
    import os

    repeatDict = {}
    for filename in os.listdir(rmaskdir):
        if 'rmsk' not in filename:
            continue
        print(filename)
        chrom = filename.split('_')[0]
        print(chrom)
        repeatDict[chrom] = {}
        infile = open(os.path.join(rmaskdir, filename))
        try:
            for i, entry in enumerate(infile):
                fields = entry.strip().split('\t')
                repeatDict[chrom][i] = {'start': int(fields[6]), 'stop': int(fields[7])}
        finally:
            infile.close()
    return repeatDict


def _readRegions(filename, chromField):
    """Read a tab-separated region file into a region dictionary.

    The chromosome is read from column chromField and start/stop from the two
    columns after it.  Lines starting with '#', blank lines and chromosomes
    whose name contains 'rand' are skipped.  Region ids are a running counter
    over the accepted lines of the file.

    Returns chromosome -> region id -> {'start', 'stop', 'intersected'}.
    """
    regions = {}
    infile = open(filename)
    try:
        i = 0
        for line in infile:
            if line.startswith('#') or not line.strip():
                continue
            fields = line.split('\n')[0].split('\t')
            chrom = fields[chromField]
            if chrom.find('rand') != -1:
                continue
            regions.setdefault(chrom, {})[i] = {
                'start': int(fields[chromField + 1]),
                'stop': int(fields[chromField + 2]),
                'intersected': 0,
            }
            i += 1
    finally:
        infile.close()
    return regions


def run():
    """Command-line entry point.

    Reads two region files, writes the number of regions of the first file
    that overlap a region of the second file, then repeats the count for
    `iterationsNumber` random re-placements of both region sets (optionally
    using the repeat-masked randomiser when -repeatmask is given).  All
    results go to the output file named on the command line.

    Exits with status 1 after printing the usage line when too few arguments
    are supplied.
    """
    # Seven positional arguments follow the script name, so argv needs at
    # least eight entries (sys.argv[7] is read below).
    if len(sys.argv) < 8:
        print('usage: python %s genome file1 [chromField1] file2 [chromField2] iterationsNumber otufilename [-repeatmask filespath] [-cache size]' % sys.argv[0])
        sys.exit(1)

    genome = sys.argv[1]
    file1 = sys.argv[2]
    chromField1 = int(sys.argv[3])
    file2 = sys.argv[4]
    chromField2 = int(sys.argv[5])
    iterations = int(sys.argv[6])
    outfilename = sys.argv[7]

    # -cache is still accepted (and its value validated) for command-line
    # compatibility, but nothing in this function uses it at present.
    cachePages = -1
    if '-cache' in sys.argv:
        cachePages = int(sys.argv[sys.argv.index('-cache') + 1])

    repeatDict = None
    if '-repeatmask' in sys.argv:
        repeatDict = _loadRepeatDict(sys.argv[sys.argv.index('-repeatmask') + 1])

    outfile = open(outfilename, 'w')
    try:
        hg = Genome(genome)

        # Chromosome lengths, skipping random, mitochondrial ('M') and
        # unplaced ('Un') sequences.
        chrDict = {}
        for name in hg.allChromNames():
            if name.find('rand') != -1 or name.find('M') != -1 or name.find('Un') != -1:
                continue
            chrDict['chr' + name] = len(hg.getChromosomeSequence(name))

        file1Dict = _readRegions(file1, chromField1)
        file2Dict = _readRegions(file2, chromField2)

        # Chromosomes seen in either input file, in a reproducible order.
        chrlist = sorted(set(file1Dict) | set(file2Dict))
        print(chrlist)

        # Both region sets need an entry for every chromosome they may be
        # asked about: the ones in the files and the ones in the genome.
        for chrom in chrlist + list(chrDict.keys()):
            file1Dict.setdefault(chrom, {})
            file2Dict.setdefault(chrom, {})
        if repeatDict is not None:
            # A chromosome without a repeat file simply has no repeats.
            for chrom in chrDict.keys():
                repeatDict.setdefault(chrom, {})

        intersect = countIntersects(file1Dict, file2Dict, chrlist, outfile)

        outfile.write('Intersecting regions: ' + str(intersect) + '\n')

        outfile.write('-----\n')
        outfile.write('-----\n')
        outfile.write('Random Sampling Iterations: \n')

        # The random dictionaries are keyed by the genome chromosomes, so the
        # random counts are taken over exactly those chromosomes.
        genomeChrList = sorted(chrDict.keys())
        for iteration in range(0, iterations):
            print("Iteration" + str(iteration))
            if repeatDict is not None:
                randomDict1 = createRandomRegionDictRepeatMasked(file1Dict, chrDict, repeatDict)
                randomDict2 = createRandomRegionDictRepeatMasked(file2Dict, chrDict, repeatDict)
            else:
                randomDict1 = createRandomRegionDict(file1Dict, chrDict)
                randomDict2 = createRandomRegionDict(file2Dict, chrDict)
            intersect = countIntersects(randomDict1, randomDict2, genomeChrList, outfile)
            outfile.write(str(intersect) + '\n')
    finally:
        outfile.close()

# Script entry point.  There is no __main__ guard, so this call is executed
# whenever the module is loaded, including on import.
run()
