##################################
#                                #
# Last modified 5/11/2009         # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
import random
from sets import Set
from cistematic.core import Genome
from cistematic.core.geneinfo import geneinfoDB
import sqlite3 as sqlite
import sys
from commoncode import *

try:
    # psyco is an optional accelerator; the script runs unchanged without it.
    import psyco
    psyco.full()
except Exception:
    # Catch ordinary errors only (missing module, unsupported platform) so that
    # KeyboardInterrupt / SystemExit are no longer swallowed by a bare except.
    print('psyco not running')

def countIntersects(regionDict1,regionDict2,chrlist,outfile):
    """Count regions of regionDict1 that overlap at least one region of regionDict2.

    Both dicts map chromosome name -> {regionId: {'start': int, 'stop': int, ...}}.
    Only chromosomes listed in chrlist are compared; a chromosome missing from
    either dict is treated as having no regions. Regions are closed intervals,
    so regions that merely touch at an endpoint count as overlapping.

    For every chromosome in chrlist one line is written to outfile with the
    average region length of each dict (0.0 when there are no regions).

    Returns the number of regionDict1 regions, summed over chrlist, that
    intersect at least one regionDict2 region on the same chromosome.
    """
    from bisect import bisect_right

    def averageLength(regions):
        lengths = [region['stop'] - region['start'] for region in regions]
        if not lengths:
            return 0.0
        return float(sum(lengths)) / len(lengths)

    intersect = 0
    for chrom in chrlist:
        regions1 = list(regionDict1.get(chrom, {}).values())
        regions2 = list(regionDict2.get(chrom, {}).values())

        line = chrom + '\t Average Dict1 regions length: ' + str(averageLength(regions1)) + '\t Average Dict2 regions length: ' + str(averageLength(regions2)) + '\n'
        outfile.write(line)

        # Sort the regionDict2 spans by start and keep a running maximum of
        # their stops. A region [a, b] overlaps something iff, among the spans
        # starting at or before b, the largest stop reaches a. This replaces
        # the all-pairs scan (O(n*m)) with O((n + m) log m).
        spans2 = sorted((region['start'], region['stop']) for region in regions2)
        starts2 = [start for start, _ in spans2]
        maxStopSoFar = []
        for _, stop in spans2:
            if maxStopSoFar:
                stop = max(stop, maxStopSoFar[-1])
            maxStopSoFar.append(stop)

        for region in regions1:
            k = bisect_right(starts2, region['stop'])  # spans starting <= region stop
            if k and maxStopSoFar[k - 1] >= region['start']:
                intersect += 1

    return intersect

def createRandomRegionDict(regionDict,chrDict):
    """Shuffle every region of regionDict to a random position in the genome.

    chrDict maps chromosome name -> chromosome length. Each region keeps its
    id and its length (stop - start). Its start is drawn uniformly over the
    concatenated genome. If the region would run past the end of the chosen
    chromosome, its start is redrawn inside that chromosome so that it fits.
    A region longer than its chromosome is placed at 0.

    Returns {chrom: {regionId: {'start': int, 'stop': int}}} with one entry per
    chrDict chromosome. Coordinates are chromosome-local. Chromosomes of
    regionDict that are not in chrDict are ignored.
    """
    from bisect import bisect_right

    chroms = list(chrDict.keys())
    randomRegionDict = dict((chrom, {}) for chrom in chroms)

    # Cumulative chromosome end offsets in the concatenated genome.
    chrEnds = []
    total = 0
    for chrom in chroms:
        total += chrDict[chrom]
        chrEnds.append(total)
    print(chrEnds)
    if total <= 0:
        return randomRegionDict

    for sourceChrom in chroms:
        for i, region in regionDict.get(sourceChrom, {}).items():
            size = region['stop'] - region['start']
            globalStart = random.randint(0, total - 1)
            k = bisect_right(chrEnds, globalStart)  # index of chromosome containing globalStart
            target = chroms[k]
            start = globalStart - (chrEnds[k - 1] if k else 0)
            if start + size > chrDict[target]:
                # Crossing the chromosome end: redraw within this chromosome.
                start = random.randint(0, max(chrDict[target] - size, 0))
            randomRegionDict[target][i] = {'start': start, 'stop': start + size}

    return randomRegionDict

def createRandomRegionDictRepeatMasked(regionDict,chrDict,repeatDict):

    randomRegionDict={}
    chrEndPositions={}
    chrEndPositionsList=[]
    t=0
    for chr in chrDict.keys():
        if t==0:
            chrEnd=chrDict[chr]
            chrEndPositionsList.append(chrEnd)
            chrEndPositions[chrEnd]=chr
            t+=1
        else:
            chrEnd=chrEndPositionsList[len(chrEndPositionsList)-1]+chrDict[chr]
            chrEndPositionsList.append(chrEnd)
            chrEndPositions[chrEnd]=chr
        randomRegionDict[chr]={}
    print chrEndPositionsList
    maxCoordinate=max(chrEndPositions.keys())
    for chr in chrDict.keys():
        c=0
        while c < len(regionDict[chr].keys()):
            i=regionDict[chr].keys()[c]
            size=regionDict[chr][i]['start']-regionDict[chr][i]['stop']
            start=random.randint(0,maxCoordinate-size)
            stop=start+size
            inrepeat=False
            if start in repeatDict[chr] or stop in repeatDict[chr] or int((start-stop)/2.0) in repeatDict[chr]:
                inrepeat=True
            if inrepeat:
                continue
            if stop < chrEndPositionsList[0]:
                chromosome=chrEndPositions[chrEndPositionsList[0]]
                randomRegionDict[chromosome][i]={}
                randomRegionDict[chromosome][i]['start']=start
                randomRegionDict[chromosome][i]['stop']=stop
                c+=1
                print chr, c
                continue
            else:
                for j in range(1,len(chrEndPositionsList)): 
                    if chrEndPositionsList[j]>stop and start>chrEndPositionsList[j-1]:
                        chromosome=chrEndPositions[chrEndPositionsList[j]]
                        randomRegionDict[chromosome][i]={}                        
                        randomRegionDict[chromosome][i]['start']=start
                        randomRegionDict[chromosome][i]['stop']=stop
                        c+=1
                        print chr, c
                        break
                    else:
                        continue

    return randomRegionDict


def _readRegionFile(filename, chromField):
    """Return the regions of a tab-delimited file as a list of (chrom, start, stop).

    The chromosome is taken from column chromField; start and stop from the two
    columns right after it. Comment lines ('#'), blank lines and chromosomes
    whose name contains 'rand' are skipped.
    """
    regions = []
    infile = open(filename)
    try:
        for line in infile:
            line = line.rstrip('\r\n')
            if not line or line[0] == '#':
                continue
            fields = line.split('\t')
            chrom = fields[chromField]
            if 'rand' in chrom:
                continue
            regions.append((chrom, int(fields[chromField + 1]), int(fields[chromField + 2])))
    finally:
        infile.close()
    return regions


def _loadRepeatMask(rmaskdir):
    """Return {chrom: set of masked positions} built from the '*rmsk*' files in rmaskdir.

    The chromosome name is the file-name prefix before the first '_'. Start
    and stop are read from tab-delimited columns 6 and 7, and [start, stop) is
    masked. Several files for the same chromosome are merged.
    """
    import os

    repeatDict = {}
    for filename in os.listdir(rmaskdir):
        if 'rmsk' not in filename:
            continue
        print(filename)
        chrom = filename.split('_')[0]
        print(chrom)
        masked = repeatDict.setdefault(chrom, set())
        infile = open(os.path.join(rmaskdir, filename))
        try:
            for entry in infile:
                if not entry.strip():
                    continue
                fields = entry.strip().split('\t')
                masked.update(range(int(fields[6]), int(fields[7])))
        finally:
            infile.close()
    return repeatDict


def run():
    """Command-line entry point.

    usage: python <script> genome file1 chromField1 file2 chromField2
           iterationsNumber outfilename [-repeatmask filespath] [-cache size]

    Counts how many regions of file1 overlap a region of file2, then repeats
    the count `iterationsNumber` times on randomly shuffled copies of both
    region sets (optionally avoiding repeat-masked positions) and writes all
    counts to outfilename. '-cache' is accepted for compatibility but unused.
    """
    if len(sys.argv) < 8:
        print('usage: python %s genome file1 chromField1 file2 chromField2 iterationsNumber outfilename [-repeatmask filespath] [-cache size]' % sys.argv[0])
        sys.exit(1)

    genome = sys.argv[1]
    file1 = sys.argv[2]
    chromField1 = int(sys.argv[3])
    file2 = sys.argv[4]
    chromField2 = int(sys.argv[5])
    iterations = int(sys.argv[6])
    outfilename = sys.argv[7]

    repeatDict = None
    if '-repeatmask' in sys.argv:
        repeatDict = _loadRepeatMask(sys.argv[sys.argv.index('-repeatmask') + 1])

    # Lengths of the genome's main chromosomes (random/M/Un contigs excluded).
    hg = Genome(genome)
    chrDict = {}
    for name in hg.allChromNames():
        if 'rand' in name or 'M' in name or 'Un' in name:
            continue
        chrDict['chr' + name] = len(hg.getChromosomeSequence(name))

    regions1 = _readRegionFile(file1, chromField1)
    regions2 = _readRegionFile(file2, chromField2)
    chrlist = list(set([r[0] for r in regions1] + [r[0] for r in regions2]))
    print(chrlist)

    # Every file chromosome and every genome chromosome gets an entry, so the
    # randomizers can look up any chrDict key and countIntersects any chrlist key.
    file1Dict = {}
    file2Dict = {}
    for chrom in chrlist + list(chrDict.keys()):
        file1Dict.setdefault(chrom, {})
        file2Dict.setdefault(chrom, {})
    for i, (chrom, start, stop) in enumerate(regions1):
        file1Dict[chrom][i] = {'start': start, 'stop': stop, 'intersected': 0}
    for i, (chrom, start, stop) in enumerate(regions2):
        file2Dict[chrom][i] = {'start': start, 'stop': stop, 'intersected': 0}

    if repeatDict is not None:
        # Genome chromosomes without a mask file are treated as unmasked.
        for chrom in chrDict:
            repeatDict.setdefault(chrom, set())

    outfile = open(outfilename, 'w')
    try:
        intersect = countIntersects(file1Dict, file2Dict, chrlist, outfile)
        outfile.write('Intersecting regions: ' + str(intersect) + '\n')

        outfile.write('-----\n')
        outfile.write('-----\n')
        outfile.write('Random Sampling Iterations: \n')

        for i in range(iterations):
            print('Iteration' + str(i))
            if repeatDict is not None:
                randomDict1 = createRandomRegionDictRepeatMasked(file1Dict, chrDict, repeatDict)
                randomDict2 = createRandomRegionDictRepeatMasked(file2Dict, chrDict, repeatDict)
            else:
                randomDict1 = createRandomRegionDict(file1Dict, chrDict)
                randomDict2 = createRandomRegionDict(file2Dict, chrDict)
            # Chromosomes present in the files but not in the genome get no
            # random regions; give them empty entries so the comparison works.
            for chrom in chrlist:
                randomDict1.setdefault(chrom, {})
                randomDict2.setdefault(chrom, {})
            intersect = countIntersects(randomDict1, randomDict2, chrlist, outfile)
            outfile.write(str(intersect) + '\n')
    finally:
        outfile.close()

# Run only when executed as a script, not when imported.
if __name__ == '__main__':
    run()
