#
#  siteintersects.py
#  ENRAGE
#

import sys
from commoncode import regionsOverlap

# Announce the script version as soon as the module is loaded.
print("siteintersects: version 2.1")


class Site():
    """One site parsed from a whitespace-separated text line.

    Two layouts are accepted:

    * default: the first field is ``<chrom>:<start>-<stop>`` and every
      following field is kept in ``rest``;
    * expanded (``doExpanded=True``): the first field is ignored, fields
      1-3 are chromosome, start and stop, and fields 4+ go to ``rest``.

    Attributes:
        chrom: chromosome field with its first three characters dropped
            (presumably a "chr" prefix -- the output writer re-adds "chr").
        start, stop: integer coordinates.
        rest: list of the remaining fields.
        fieldList: the same list object as ``rest``; kept because callers
            read ``site.fieldList``.
    """

    def __init__(self, line, doExpanded=False):
        fields = line.strip().split()
        if doExpanded:
            self.chrom = fields[1][3:]
            self.start = int(fields[2])
            self.stop = int(fields[3])
            self.rest = fields[4:]
        else:
            (chromosome, pos) = fields[0].split(":")
            self.chrom = chromosome[3:]
            (start, stop) = pos.split("-")
            self.start = int(start)
            self.stop = int(stop)
            self.rest = fields[1:]

        # Callers access ``fieldList``; without it every parse raised AttributeError.
        self.fieldList = self.rest



def main(argv=None):
    """Command-line entry point.

    Usage: sitefile1 sitefile2 outfile [--reject rejectfile1 rejectfile2] [--expanded]

    Prints the usage line and exits with status 1 when fewer than three
    positional arguments are given or when ``--reject`` is not followed by
    two filenames.

    Args:
        argv: argument vector (program name first); defaults to sys.argv.
    """
    if not argv:
        argv = sys.argv

    usage = "usage: python %s sitefile1 sitefile2 outfile [--reject rejectfile1 rejectfile2] [--expanded]" % argv[0]
    if len(argv) < 4:
        print(usage)
        sys.exit(1)

    sitefilename1 = argv[1]
    sitefilename2 = argv[2]
    outfilename = argv[3]

    # Flags are read from argv (not sys.argv) so programmatic callers work too.
    reject1filename = None
    reject2filename = None
    doReject = False
    if "--reject" in argv:
        rejectIndex = argv.index("--reject")
        if rejectIndex + 2 >= len(argv):
            print(usage)
            sys.exit(1)

        # Pass filenames, not open files: siteintersects() opens them itself.
        reject1filename = argv[rejectIndex + 1]
        reject2filename = argv[rejectIndex + 2]
        doReject = True

    doExpanded = "--expanded" in argv

    siteintersects(sitefilename1, sitefilename2, outfilename, reject1filename, reject2filename, doReject, doExpanded)


def siteintersects(siteFilename, siteCompareFilename, outfilename, siteRejectFilename=None, compareRejectFilename=None, doReject=False, doExpanded=False):
    """Intersect two site files and print the number of overlaps found.

    Sites from siteFilename are indexed first.
    Each site in siteCompareFilename is then checked against that index,
    and matches are written to outfilename.
    When doReject is set:

    * compareSites() writes compare-file sites with no overlap to
      compareRejectFilename;
    * sites from the first file that were never matched are written to
      siteRejectFilename (if one is given).
    """
    sitesByChrom, unmatchedLines = getSiteDicts(siteFilename, doExpanded, doReject)
    overlapCount = compareSites(siteCompareFilename, compareRejectFilename, outfilename,
                                sitesByChrom, unmatchedLines, doExpanded, doReject)

    writeFirstRejects = doReject and siteRejectFilename is not None
    if writeFirstRejects:
        writeRejectSiteFile(siteRejectFilename, unmatchedLines)

    print(overlapCount)


def getSiteDicts(sitefilename, doExpanded=False, doReject=False):
    """Load a site file and index its sites by chromosome.

    The first line of the file is skipped as a header, and lines for which
    doNotProcessLine() is true are ignored. Prints the number of processed
    lines as "file1: N".

    Args:
        sitefilename: path of the site file to read.
        doExpanded: parse lines using Site's expanded layout.
        doReject: also build the reject dictionary.

    Returns:
        (siteDict, rejectSiteDict):
        - siteDict maps chrom -> list of (start, stop, rest) tuples, in file
          order.
        - rejectSiteDict maps str((chrom, start, stop, rest)) -> the original
          line. It stays empty unless doReject is set.
    """
    siteDict = {}
    rejectSiteDict = {}
    processedLineCount = 0
    # ``with`` guarantees the file is closed even if a line fails to parse.
    with open(sitefilename) as infile:
        infile.readline()  # skip header line
        for line in infile:
            if doNotProcessLine(line):
                continue

            processedLineCount += 1
            site = Site(line, doExpanded)
            chrom = site.chrom
            start = site.start
            stop = site.stop
            # Site exposes the trailing fields as ``rest``.
            rest = site.rest

            siteDict.setdefault(chrom, []).append((start, stop, rest))

            if doReject:
                # The tuple holds a list (unhashable), so its str() is the key.
                rejectSiteDict[str((chrom, start, stop, rest))] = line

    print("file1: %d" % processedLineCount)

    return siteDict, rejectSiteDict


def doNotProcessLine(line):
    """Return True if ``line`` should be skipped instead of parsed as a site.

    Comment lines (starting with "#") are skipped. Empty or whitespace-only
    lines are skipped too: they have no fields, so parsing them would raise
    IndexError.
    """
    return not line.strip() or line.startswith("#")


def compareSites(siteFilename, rejectFilename, outfilename, siteDict, rejectSiteDict, doExpanded, doReject):
    """Scan a second site file and report sites that overlap siteDict.

    The first line of siteFilename is skipped as a header, and lines for
    which doNotProcessLine() is true are ignored. Every overlapping
    (site, siteDict entry) pair is counted, but only the first overlap per
    site is written to outfilename as a "commonN" line. Prints the number of
    processed lines as "file2: N".

    Args:
        siteFilename: path of the site file to scan.
        rejectFilename: where to write sites with no overlap, or None.
        outfilename: output path (overwritten).
        siteDict: chrom -> list of (start, stop, rest) tuples.
        rejectSiteDict: reject candidates keyed by
            str((chrom, start, stop, rest)). When rejects are enabled,
            overlapping entries are removed in place.
        doExpanded: parse lines using Site's expanded layout.
        doReject: write rejects; ignored when rejectFilename is None.

    Returns:
        The total number of overlapping pairs.
    """
    doReject = doReject and rejectFilename is not None

    # with/finally closes every file even if a line fails to parse.
    with open(siteFilename) as infile:
        infile.readline()  # skip header line
        rejectfile = open(rejectFilename, "w") if doReject else None
        try:
            with open(outfilename, "w") as outfile:
                processedLineCount, commonSites = _scanCompareFile(
                    infile, outfile, rejectfile, siteDict, rejectSiteDict, doExpanded)
        finally:
            if rejectfile is not None:
                rejectfile.close()

    print("file2: %d" % processedLineCount)

    return commonSites


def _scanCompareFile(infile, outfile, rejectfile, siteDict, rejectSiteDict, doExpanded):
    """Match each remaining line of infile against siteDict.

    - Writes "commonN" lines to outfile.
    - When rejectfile is not None, writes non-overlapping lines to it and
      prunes overlapping entries from rejectSiteDict.

    Returns (processedLineCount, commonSites).
    """
    processedLineCount = 0
    commonSites = 0
    for line in infile:
        if doNotProcessLine(line):
            continue

        processedLineCount += 1
        site = Site(line, doExpanded)
        chrom = site.chrom
        start = site.start
        stop = site.stop
        if chrom not in siteDict:
            if rejectfile is not None:
                rejectfile.write(line)

            continue

        siteNotCommon = True
        for (rstart, rstop, rline) in siteDict[chrom]:
            if not regionsOverlap(start, stop, rstart, rstop):
                continue

            commonSites += 1
            if siteNotCommon:
                # Only the first overlap per site is written; later ones are only counted.
                outfile.write("common%d\tchr%s\t%d\t%d\t%s\tchr%s\t%d\t%d\t%s\n" % (commonSites, chrom, rstart, rstop, str(rline), chrom, start, stop, site.rest))
                siteNotCommon = False

            if rejectfile is not None:
                # This entry overlapped something, so it is no longer a reject.
                rejectSiteDict.pop(str((chrom, rstart, rstop, rline)), None)

        if rejectfile is not None and siteNotCommon:
            rejectfile.write(line)

    return processedLineCount, commonSites


def writeRejectSiteFile(siteRejectFilename, rejectSiteDict):
    """Write every value of rejectSiteDict to siteRejectFilename.

    The file is overwritten. Values are written verbatim, in the dictionary's
    iteration order, with no separator added between them.
    """
    # ``with`` closes the file even if a write fails.
    with open(siteRejectFilename, "w") as rejectfile:
        rejectfile.writelines(rejectSiteDict.values())
    

if __name__ == "__main__":
    # main() falls back to sys.argv when called without an argument vector.
    main()
