##################################
#                                #
# Last modified 09/14/2012       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from sets import Set

# psyco is an optional JIT accelerator for old Python 2 interpreters; the
# script works unchanged without it, so any failure to load it is only
# reported. Catching Exception (not a bare except) lets Ctrl-C and
# sys.exit() propagate normally.
try:
    import psyco
    psyco.full()
except Exception:
    print('psyco not running')

def _parse_field_spec(spec, allow_single_field):
    """Parse a chromField argument into (field_index, is_single_field).

    A spec such as '3,SF' marks the column as holding 'chrN:left-right'; the
    ',SF' suffix is recognised only when -singleField was passed. Otherwise
    the spec must be a plain integer column index.
    """
    if allow_single_field and ',SF' in spec:
        return int(spec.split(',')[0]), True
    return int(spec), False


def _read_regions(path, chrom_field, single_field, peak_field=None):
    """Read the regions listed in a tab-delimited file.

    Lines starting with '#' and blank lines are skipped. Each remaining line
    becomes a dict with keys 'line' (the raw input line), 'chrom', 'start',
    'stop', 'intersected' (initially False) and, when peak_field is given,
    'peak'. Coordinates come from the single 'chrN:left-right' column at
    chrom_field when single_field is true, otherwise from the three
    consecutive columns chrom_field, chrom_field+1 and chrom_field+2.
    """
    regions = []
    with open(path) as infile:
        for line in infile:
            stripped = line.strip()
            if not stripped or line.startswith('#'):
                continue
            fields = stripped.split('\t')
            if single_field:
                location = fields[chrom_field].split(':')
                chrom = location[0]
                bounds = location[1].split('-')
                start = int(bounds[0])
                stop = int(bounds[1])
            else:
                chrom = fields[chrom_field]
                start = int(fields[chrom_field + 1])
                stop = int(fields[chrom_field + 2])
            region = {'line': line, 'chrom': chrom, 'start': start,
                      'stop': stop, 'intersected': False}
            if peak_field is not None:
                region['peak'] = int(fields[peak_field])
            regions.append(region)
    return regions


def _group_by_chrom(regions):
    """Map each chromosome name to its regions, preserving file order."""
    grouped = {}
    for region in regions:
        grouped.setdefault(region['chrom'], []).append(region)
    return grouped


def _overlaps(region1, region2, peak_mode):
    """Return True when region2 hits region1, using inclusive coordinates.

    In peak mode region2's 'peak' must lie within region1's [start, stop];
    otherwise the two [start, stop] intervals must share at least one point.
    """
    if peak_mode:
        return region1['start'] <= region2['peak'] <= region1['stop']
    return (region2['start'] <= region1['stop'] and
            region2['stop'] >= region1['start'])


def _select_lines(by_chrom, chroms, intersected):
    """Return the raw lines whose 'intersected' flag equals `intersected`.

    Lines are listed chromosome by chromosome in the order of `chroms` and,
    within a chromosome, in input-file order.
    """
    return [region['line']
            for chrom in chroms
            for region in by_chrom.get(chrom, [])
            if region['intersected'] == intersected]


def _write_lines(path, lines):
    """Write each string in `lines` to `path`, adding a newline where missing."""
    with open(path, 'w') as outfile:
        for line in lines:
            if not line.endswith('\n'):
                line += '\n'
            outfile.write(line)


def run():
    """Intersect two region files chromosome by chromosome.

    Command line (read from sys.argv):
      file1 chromField1 file2 chromField2 outfilenameprefix
      [-peak2contained Peakfield] [-minOverlap fraction] [-singleField]

    A (file1, file2) region pair on the same chromosome is accepted when the
    intervals overlap (or, with -peak2contained, the file2 peak lies inside
    the file1 region) and the overlap covers at least -minOverlap (default
    0) of the file1 region's length. Five files are written:
      <prefix>-intersect_regions.txt   chrom, start, stop of every accepted pair
      <prefix>-intersection1/2         input lines with at least one accepted pair
      <prefix>-outersection1/2         input lines with none
    Exits with status 1 and a usage message when arguments are missing.
    """
    argv = sys.argv
    # argv[5] (the output prefix) is required, so at least 6 entries.
    if len(argv) < 6:
        print('usage: python %s file1 chromField1 file2 chromField2 outfilenameprefix [-peak2contained Peakfield] [-minOverlap fraction-of-1st-file-region] [-singleField]' % argv[0])
        print("       Use the -singleField option if regions are specified in a chrN:left-right format; add ',SF' to chromField1 and/or chromField2 for each file in that format")
        sys.exit(1)

    file1 = argv[1]
    file2 = argv[3]
    outfileprefix = argv[5]

    min_overlap = 0.0
    if '-minOverlap' in argv:
        min_overlap = float(argv[argv.index('-minOverlap') + 1])

    allow_single_field = '-singleField' in argv
    chrom_field1, single_field1 = _parse_field_spec(argv[2], allow_single_field)
    chrom_field2, single_field2 = _parse_field_spec(argv[4], allow_single_field)

    peak_field = None
    if '-peak2contained' in argv:
        print('requiring second set peaks to be contained within first set regions')
        peak_field = int(argv[argv.index('-peak2contained') + 1])
    peak_mode = peak_field is not None

    by_chrom1 = _group_by_chrom(
        _read_regions(file1, chrom_field1, single_field1))
    by_chrom2 = _group_by_chrom(
        _read_regions(file2, chrom_field2, single_field2, peak_field))
    chroms = sorted(set(by_chrom1) | set(by_chrom2))

    intersect_lines = []
    for chrom in chroms:
        print(chrom)
        for region1 in by_chrom1.get(chrom, []):
            length1 = region1['stop'] - region1['start']
            # Every file2 region is checked (no early break), so all hits
            # are recorded and flagged, not just the first one found.
            for region2 in by_chrom2.get(chrom, []):
                if not _overlaps(region1, region2, peak_mode):
                    continue
                start = max(region1['start'], region2['start'])
                stop = min(region1['stop'], region2['stop'])
                # A zero-length file1 region counts as fully covered by any
                # hit; this avoids dividing by zero.
                if length1 > 0:
                    fraction = float(stop - start) / length1
                else:
                    fraction = 1.0
                if fraction < min_overlap:
                    continue
                intersect_lines.append('%s\t%d\t%d\n' % (chrom, start, stop))
                region1['intersected'] = True
                region2['intersected'] = True

    _write_lines(outfileprefix + '-intersect_regions.txt', intersect_lines)
    _write_lines(outfileprefix + '-intersection1',
                 _select_lines(by_chrom1, chroms, True))
    _write_lines(outfileprefix + '-intersection2',
                 _select_lines(by_chrom2, chroms, True))
    _write_lines(outfileprefix + '-outersection1',
                 _select_lines(by_chrom1, chroms, False))
    _write_lines(outfileprefix + '-outersection2',
                 _select_lines(by_chrom2, chroms, False))

# Run only when executed as a script, so the helpers can be imported
# without triggering argument parsing and file I/O.
if __name__ == '__main__':
    run()
