##################################
#                                #
# Last modified 2018/03/13       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import bisect
import math
import os
import string
import subprocess
import sys
from sets import Set

def _region_lines(filename):
    """Yield the text lines of a region file, decompressing it if needed.

    Files ending in .bz2, .gz or .zip are streamed through bzip2, gunzip or
    unzip respectively; any other file is read directly.

    Raises:
        IOError: filename does not exist.
    """
    # Fail loudly rather than treating a missing file as having no regions,
    # which would silently produce all-zero scores.
    if not os.path.exists(filename):
        raise IOError('region file not found: ' + filename)
    decompressors = (('.bz2', ['bzip2', '-cd']),
                     ('.gz', ['gunzip', '-c']),
                     ('.zip', ['unzip', '-p']))
    for suffix, command in decompressors:
        if filename.endswith(suffix):
            # Argument list and no shell, so filenames containing spaces or
            # shell metacharacters are passed through verbatim.
            proc = subprocess.Popen(command + [filename],
                                    stdout=subprocess.PIPE,
                                    universal_newlines=True)
            try:
                for line in proc.stdout:
                    yield line
            finally:
                proc.stdout.close()
                proc.wait()
            return
    with open(filename) as handle:
        for line in handle:
            yield line


def _load_regions(filename, chrField, peakField=None, reportSkipped=False):
    """Parse the regions of one file, grouped by chromosome.

    Column chrField holds the chromosome; columns chrField+1 and chrField+2
    hold start and stop (0-based column indices).

    Returns:
        ({chrom: [(start, stop, peak), ...]}, number_of_regions), where peak
        is int(fields[peakField]) when peakField is given, otherwise None.

    Blank lines, lines starting with '#' or 'track type', and lines without
    integer coordinates (or too few columns) are ignored. With reportSkipped,
    lines whose start column is purely alphabetic (column headers) are
    announced on stdout as "skipping <line>".
    """
    regions = {}
    count = 0
    for rawLine in _region_lines(filename):
        line = rawLine.strip()
        # A blank line is skipped; it does not mark the end of the file.
        if not line or line.startswith(('#', 'track type')):
            continue
        fields = line.split('\t')
        if (reportSkipped and len(fields) > chrField + 1
                and fields[chrField + 1].isalpha()):
            sys.stdout.write('skipping ' + line + '\n')
            continue
        try:
            start = int(fields[chrField + 1])
            stop = int(fields[chrField + 2])
            peak = int(fields[peakField]) if peakField is not None else None
        except (IndexError, ValueError):
            continue
        regions.setdefault(fields[chrField], []).append((start, stop, peak))
        count += 1
    return regions, count


def _count_overlaps(regions1, regions2, doPeak, minOverlap):
    """Return (hits1, hits2): how many regions of each set are in a qualifying pair.

    A (file1, file2) pair qualifies when the regions overlap (endpoints
    inclusive) -- or, with doPeak, when the file2 peak lies inside the file1
    region -- and the overlap length is at least minOverlap times the file1
    region's length. A zero-length file1 region counts as fully covered.
    """
    hits1 = 0
    hits2 = 0
    for chrom, list2 in regions2.items():
        list1 = regions1.get(chrom)
        if not list1:
            continue
        # Sort file2 regions by the coordinate that bounds the candidate
        # search: the peak in doPeak mode, the start otherwise.
        keyIndex = 2 if doPeak else 0
        order = sorted(range(len(list2)), key=lambda k: list2[k][keyIndex])
        keys = [list2[k][keyIndex] for k in order]
        # A file2 region can only reach back to start1 if its start is at
        # most maxLength2 positions before start1.
        maxLength2 = max(stop - start for start, stop, _ in list2)
        hit2 = [False] * len(list2)
        for start1, stop1, _ in list1:
            if doPeak:
                lo = bisect.bisect_left(keys, start1)
            else:
                lo = bisect.bisect_left(keys, start1 - maxLength2)
            hi = bisect.bisect_right(keys, stop1)
            length1 = stop1 - start1
            matched = False
            # Examine every candidate: stopping at the first overlap would
            # miss later regions that satisfy minOverlap and would leave
            # those file2 regions unmarked.
            for pos in range(lo, hi):
                k = order[pos]
                start2, stop2, _ = list2[k]
                if not doPeak and stop2 < start1:
                    continue
                overlap = min(stop1, stop2) - max(start1, start2)
                if length1 == 0:
                    fraction = 1.0
                else:
                    fraction = float(overlap) / length1
                if fraction >= minOverlap:
                    matched = True
                    hit2[k] = True
            if matched:
                hits1 += 1
        hits2 += sum(hit2)
    return hits1, hits2


def intersect(file1, field1, file2, field2, doPeak, minOverlap, peakField=None):
    """Measure how strongly the regions of file1 overlap the regions of file2.

    Both files are tab-separated text, optionally .bz2/.gz/.zip compressed.
    Column `field` holds the chromosome and columns field+1 / field+2 hold
    start / stop (0-based column indices). Comment lines ('#'), 'track type'
    lines, blank lines and lines without integer coordinates are ignored.
    In file1, a line whose start column is alphabetic is reported as
    "skipping <line>" on stdout.

    Args:
        file1, field1: first region file and its chromosome column.
        file2, field2: second region file and its chromosome column.
        doPeak: if true, a file2 region matches a file1 region when its peak
            (column peakField of file2) lies inside the file1 region, instead
            of when the two regions overlap.
        minOverlap: minimum fraction of the file1 region's length that the
            overlap must cover for a pair to count.
        peakField: 0-based column of file2 holding the peak position;
            required when doPeak is true.

    Returns:
        (score1, score2, Jaccard, File1Regions, File2Regions), where
        File1Regions/File2Regions are the region counts as floats,
        score1 = intersect1 / File1Regions, score2 = intersect2 / File2Regions
        and Jaccard = intersect1 / (File1Regions + File2Regions - intersect2);
        intersectN is the number of regions of file N in at least one
        qualifying pair. All three scores are 0 if either file is empty.

    Raises:
        ValueError: doPeak is true but peakField was not given.
        IOError: one of the region files does not exist.
    """
    if doPeak and peakField is None:
        raise ValueError('doPeak requires peakField, the 0-based column of '
                         'the peak position in file2')

    regions1, count1 = _load_regions(file1, field1, reportSkipped=True)
    regions2, count2 = _load_regions(file2, field2,
                                     peakField=peakField if doPeak else None)
    File1Regions = float(count1)
    File2Regions = float(count2)

    if File1Regions == 0 or File2Regions == 0:
        return (0, 0, 0, File1Regions, File2Regions)

    intersect1, intersect2 = _count_overlaps(regions1, regions2, doPeak,
                                             minOverlap)
    score1 = intersect1 / File1Regions
    score2 = intersect2 / File2Regions
    Jaccard = intersect1 / (File1Regions + File2Regions - intersect2)
    return (score1, score2, Jaccard, File1Regions, File2Regions)

def _read_dataset_list(listFilename):
    """Return (label, regionFilename, chrField) tuples from a dataset list file.

    Each useful line is tab-separated: label, region-calls filename, and the
    0-based column index of the chromosome in that file. Lines with fewer
    than three columns are skipped.
    """
    datasets = []
    with open(listFilename) as listFile:
        for line in listFile:
            fields = line.strip().split('\t')
            # All three columns are read, so all three must be present.
            if len(fields) < 3:
                continue
            datasets.append((fields[0], fields[1], int(fields[2])))
    return datasets


def run():
    """Command-line entry point.

    usage: <list of files1> <list of files2> outfilename [-doPeak]
           [-minOverlap fraction] [-paired] [-Jaccard]

    Compares every dataset of the first list with every dataset of the second
    (or, with -paired, the i-th with the i-th) using intersect(), and writes
    a tab-separated table to outfilename. The reported score is intersect()'s
    score1 (fraction of first-set regions overlapped), or its Jaccard value
    with -Jaccard. Labels in the second list that collide with labels in the
    first are prefixed with 'file2_'.
    """
    # Three positional arguments are required: sys.argv[1], [2] and [3].
    if len(sys.argv) < 4:
        sys.stdout.write('usage: python %s <list of files1> <list of files2> outfilename [-doPeak] [-minOverlap fraction] [-paired] [-Jaccard]\n' % sys.argv[0])
        sys.stdout.write('\tthe first set of files will be compared against the second\n')
        sys.stdout.write('\tformat of list of files:\n')
        sys.stdout.write('\tlabel\tregionCalls-filename\tchrfield\n')
        sys.stdout.write('\tfiles can be in .gz, .bz2 or .zip format\n')
        sys.stdout.write('\tthe [-paired] option will only compare the first set from the first file with first set from the second, the second with the second, etc.\n')
        sys.exit(1)

    input1 = sys.argv[1]
    input2 = sys.argv[2]
    outputfilename = sys.argv[3]

    # NOTE(review): intersect() needs to know which file2 column holds the
    # peak for -doPeak; no command-line option supplies it here.
    doPeak = '-doPeak' in sys.argv
    doJaccard = '-Jaccard' in sys.argv
    doPaired = '-paired' in sys.argv

    minOverlap = 0
    if '-minOverlap' in sys.argv:
        try:
            minOverlap = float(sys.argv[sys.argv.index('-minOverlap') + 1])
        except (IndexError, ValueError):
            sys.stdout.write('-minOverlap requires a numeric fraction, exiting\n')
            sys.exit(1)

    DataDict = {}
    ScoreDict = {}
    InputList1 = []
    for ID, regionFile, field in _read_dataset_list(input1):
        DataDict[ID] = (regionFile, field)
        ScoreDict[ID] = {}
        InputList1.append(ID)

    InputList2 = []
    for ID, regionFile, field in _read_dataset_list(input2):
        if ID in DataDict:
            ID = 'file2_' + ID
        DataDict[ID] = (regionFile, field)
        for ID1 in InputList1:
            ScoreDict[ID1][ID] = 0
        InputList2.append(ID)

    # Validate before opening the output so a bad invocation does not
    # truncate an existing output file.
    if doPaired and len(InputList2) != len(InputList1):
        sys.stdout.write('different number of input files specified with [-paired] option, exiting\n')
        sys.exit(1)

    with open(outputfilename, 'w') as outfile:
        if doPaired:
            outfile.write('#DataSet1\tDataSet2\tNumRegions1\tNumRegions2\tOverlap1\t\n')
            # Pair by position; looking labels up by value mispairs repeats.
            for ID1, ID2 in zip(InputList1, InputList2):
                file1, field1 = DataDict[ID1]
                file2, field2 = DataDict[ID2]
                (score1, score2, Jaccard, File1Regions, File2Regions) = intersect(file1, field1, file2, field2, doPeak, minOverlap)
                score = Jaccard if doJaccard else score1
                outline = '\t'.join([ID1, ID2, str(File1Regions),
                                     str(File2Regions), str(score)])
                sys.stdout.write(outline + '\n')
                outfile.write(outline + '\n')
        else:
            RegionNumberDict = {}
            for ID1 in InputList1:
                for ID2 in InputList2:
                    sys.stdout.write(ID1 + ' vs ' + ID2 + '\n')
                    file1, field1 = DataDict[ID1]
                    file2, field2 = DataDict[ID2]
                    (score1, score2, Jaccard, File1Regions, File2Regions) = intersect(file1, field1, file2, field2, doPeak, minOverlap)
                    ScoreDict[ID1][ID2] = Jaccard if doJaccard else score1
                    RegionNumberDict[ID1] = File1Regions
                    RegionNumberDict[ID2] = File2Regions

            sortedIDs1 = sorted(InputList1)
            sortedIDs2 = sorted(InputList2)
            # Row 1: column labels; row 2: region counts of the second set.
            # A dataset never compared (the other list is empty) reports 0.
            outfile.write('\t'.join(['#DataSet:', ''] + sortedIDs2) + '\n')
            counts2 = [str(int(RegionNumberDict.get(ID, 0))) for ID in sortedIDs2]
            outfile.write('\t'.join(['', 'Regions'] + counts2) + '\n')
            for ID1 in sortedIDs1:
                row = [ID1, str(int(RegionNumberDict.get(ID1, 0)))]
                row.extend(str(ScoreDict[ID1][ID2]) for ID2 in sortedIDs2)
                outfile.write('\t'.join(row) + '\n')

# Run the command-line interface only when executed as a script, so the
# module can be imported (e.g. to reuse intersect()) without side effects.
if __name__ == '__main__':
    run()

