##################################
#                                #
# Last modified 2022/12/16       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
import os
from sets import Set

def _shell_quote(path):
    """Return path single-quoted so a POSIX shell passes it on as one literal word."""
    return "'" + path.replace("'", "'\\''") + "'"


def _open_input(path):
    """Return a readable pipe over path, decompressing by suffix (.bz2, .gz, .bgz)."""
    if path.endswith('.bz2'):
        tool = 'bzip2 -cd '
    elif path.endswith('.gz') or path.endswith('.bgz'):
        tool = 'zcat '
    else:
        tool = 'cat '
    # The name is quoted so that paths containing spaces or shell
    # metacharacters reach the tool unchanged.
    return os.popen(tool + _shell_quote(path), "r")


def _option_value(flag, convert, default):
    """Return convert(<argument following flag in sys.argv>), or default if flag is absent."""
    if flag in sys.argv:
        return convert(sys.argv[sys.argv.index(flag) + 1])
    return default


def _region_positions(start, end):
    """Return the positions a region occupies; a 0bp region occupies its start position."""
    if end == start:
        return [start]
    return range(start, end)


def run():
    """Intersect the regions of two tab-delimited files named on the command line.

    Positional arguments (sys.argv[1:6]): file1, chromField1, file2,
    chromField2, outfilenameprefix.  Left/right coordinate columns default to
    chromField+1 / chromField+2 and can be overridden with -left1, -right1,
    -left2, -right2.

    Every position covered by a file1 region is stored in memory, then file2
    is streamed against that index.  Output files:

      <prefix>-intersection1   file1 lines touched by at least one file2 region
      <prefix>-intersection2   file2 lines meeting the overlap criterion
      <prefix>-outersection1   remaining file1 lines (omitted with -noOS)
      <prefix>-outersection2   remaining file2 lines (omitted with -noOS)

    -minOverlap (fraction of the file2 region) and -minOverlapBP (base pairs)
    only affect which file2 lines are kept; a file1 line counts as
    intersected as soon as one of its positions is hit by any file2 region.
    If both options are given, -minOverlapBP takes precedence.

    Exits with status 1 after printing usage when fewer than five positional
    arguments are supplied.
    """
    # Five positional arguments are required, so argv needs at least 6 entries.
    if len(sys.argv) < 6:
        print('usage: python %s file1 chromField1 file2 chromField2 outfilenameprefix [-left1 fieldID] [-left2 fieldID] [-right1 fieldID] [-right2 fieldID] [-minOverlap fraction-of-2nd-file-region] [-noOS] [-minOverlapBP bp] [-combinedOutput]' % sys.argv[0])
        print("\tNote: have file1 be the smaller file; the script will store its coordinates in the memory in a dictionary and check against the other file as it goes over its entries")
        print("\tNote: works with 0bp-sized regions")
        print("\tNote: works with .gz, .bgz and .bz2 files")
        print("\tNote: if you use the minimal overlap option, only the regions from the second files will be intersected under that criteria, i.e. what is in the intersection2 files; regions in intersection1 will be intersected using the 1bp overlap crtieria")
        print("\tUse the [-noOS] option if you do not want the outersecs to be written to disk")
        sys.exit(1)

    file1 = sys.argv[1]
    chromField1 = int(sys.argv[2])
    file2 = sys.argv[3]
    chromField2 = int(sys.argv[4])
    outfileprefix = sys.argv[5]

    minOverlap = _option_value('-minOverlap', float, 0)
    if '-minOverlap' in sys.argv:
        print('will require overlap of %s fraction of entries in the second file' % minOverlap)

    L1FID = _option_value('-left1', int, chromField1 + 1)
    R1FID = _option_value('-right1', int, chromField1 + 2)
    L2FID = _option_value('-left2', int, chromField2 + 1)
    R2FID = _option_value('-right2', int, chromField2 + 2)

    doOS = '-noOS' not in sys.argv
    if not doOS:
        print('will not print out outersects')

    doMinBasePairOverlap = '-minOverlapBP' in sys.argv
    minOverlapBP = _option_value('-minOverlapBP', int, 0)
    if doMinBasePairOverlap:
        print('will require overlap of %s bp of entries in the second file' % minOverlapBP)

    # coverage[chrom][position] -> list of file1 line numbers covering it.
    coverage = {}
    # file1 line number -> raw line text, for writing out at the end.
    file1_lines = {}

    stream = _open_input(file1)
    try:
        for i, line in enumerate(stream, 1):
            if i % 100000 == 0:
                print(str(i / 1000000.) + 'M lines in file1 processed')
            if line.startswith('#') or line.startswith('track type'):
                continue
            if not line.strip():
                continue
            fields = line.strip().split('\t')
            try:
                chrom = fields[chromField1]
                start = int(fields[L1FID])
                end = int(fields[R1FID])
            except (IndexError, ValueError):
                # Too few columns or non-numeric coordinates (e.g. a header).
                print('skipping line')
                print(line.strip())
                continue
            by_pos = coverage.setdefault(chrom, {})
            for pos in _region_positions(start, end):
                by_pos.setdefault(pos, []).append(i)
            file1_lines[i] = line
    finally:
        stream.close()

    print('finished parsing file1')

    overlapped_ids = set()

    opened = []
    try:
        out_inter1 = open(outfileprefix + '-intersection1', 'w')
        opened.append(out_inter1)
        out_inter2 = open(outfileprefix + '-intersection2', 'w')
        opened.append(out_inter2)
        out_outer1 = None
        out_outer2 = None
        if doOS:
            out_outer1 = open(outfileprefix + '-outersection1', 'w')
            opened.append(out_outer1)
            out_outer2 = open(outfileprefix + '-outersection2', 'w')
            opened.append(out_outer2)

        stream = _open_input(file2)
        try:
            for k, line in enumerate(stream, 1):
                if k % 100000 == 0:
                    print(str(k / 1000000.) + 'M lines in file2 processed')
                if line.startswith('#') or line.startswith('track type'):
                    continue
                fields = line.strip().split('\t')
                try:
                    # Anything after a ':' in the chromosome field is ignored.
                    chrom = fields[chromField2].split(':')[0]
                    start = int(fields[L2FID])
                    end = int(fields[R2FID])
                except (IndexError, ValueError):
                    print('skipping line')
                    print(line.strip())
                    continue

                if chrom not in coverage:
                    if doOS:
                        out_outer2.write(line)
                    continue

                by_pos = coverage[chrom]
                overlapBP = 0
                # A 0bp region is tested at its start position, mirroring how
                # 0bp regions from file1 are stored.
                for pos in _region_positions(start, end):
                    hits = by_pos.get(pos)
                    if hits is not None:
                        overlapped_ids.update(hits)
                        overlapBP += 1

                if doMinBasePairOverlap:
                    keep = overlapBP >= minOverlapBP
                elif overlapBP == 0:
                    keep = False
                else:
                    # Treat a 0bp region as one position wide so the fraction
                    # is defined.
                    span = float(end - start)
                    if span == 0:
                        span = 1.0
                    keep = overlapBP / span >= minOverlap

                if keep:
                    out_inter2.write(line)
                elif doOS:
                    out_outer2.write(line)
        finally:
            stream.close()

        print(sorted(overlapped_ids))
        print(len(overlapped_ids))

        # Sorted line numbers keep file1's original order in the outputs.
        for i in sorted(file1_lines):
            if i in overlapped_ids:
                out_inter1.write(file1_lines[i])
            elif doOS:
                out_outer1.write(file1_lines[i])
    finally:
        for handle in opened:
            handle.close()

# Only run when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    run()
