#
#  intersects.py
#  ENRAGE
#

import sys
import optparse
from commoncode import getConfigParser, getConfigOption, getConfigBoolOption, getConfigIntOption

# Announce the script version on import/run (a single parenthesized string
# prints identically under the Python 2 print statement).
print("intersects: version 2.1")

def main(argv=None):
    """Command-line entry point: parse arguments and run intersects().

    argv -- full argument vector including the program name; when None or
            empty, sys.argv is used instead.

    Prints the usage string and exits with status 1 when fewer than three
    positional arguments (infile1, infile2, outfile) are supplied.
    """
    commandLine = argv or sys.argv
    usage = "usage: python %prog infile1 infile2 outfile [options]"

    options, positional = getParser(usage).parse_args(commandLine[1:])
    if len(positional) < 3:
        print(usage)
        sys.exit(1)

    firstName, secondName, outName = positional[:3]
    intersects(firstName, secondName, outName, options.delimiter, options.infile3,
               options.matchField1, options.matchField2, options.matchField3,
               options.rejectFileName, options.trackGID)


def getParser(usage):
    """Build the optparse parser for intersects, seeded with config-file defaults.

    Defaults are looked up in the config section below through the commoncode
    helpers; the fallback values passed to them (presumably used when an
    option is absent from the config) are: tab delimiter, no third file,
    match field 0 for every file, no reject file, and GID tracking off.

    usage -- usage string handed to optparse.OptionParser
    Returns the configured optparse.OptionParser.
    """
    parser = optparse.OptionParser(usage=usage)
    parser.add_option("-d", dest="delimiter", help="field delimiter")
    parser.add_option("--file3", dest="infile3", help="optional third input file")
    # dest names must match the keys passed to set_defaults() below and the
    # attributes main() reads (options.matchField1, options.rejectFileName, ...).
    parser.add_option("-1", type="int", dest="matchField1",
                      help="0-based match column in infile1")
    parser.add_option("-2", type="int", dest="matchField2",
                      help="0-based match column in infile2")
    parser.add_option("-3", type="int", dest="matchField3",
                      help="0-based match column in file3")
    # optparse only accepts "-x" or "--name" option strings; single-dash
    # multi-character strings such as "-reject1" raise OptionError.
    parser.add_option("--reject1", dest="rejectFileName",
                      help="write infile1 candidates with no match to this file")
    parser.add_option("--trackGID", action="store_true", dest="trackGID",
                      help="carry the column after each match field through as a GID")

    configParser = getConfigParser()
    # NOTE(review): this section name looks copied from another script;
    # confirm which config section intersects is meant to read.
    section = "geneMrnaCountsWeighted"
    delimiter = getConfigOption(configParser, section, "delimiter", "\t")
    infile3 = getConfigOption(configParser, section, "infile3", None)
    matchField1 = getConfigIntOption(configParser, section, "matchField1", 0)
    matchField2 = getConfigIntOption(configParser, section, "matchField2", 0)
    matchField3 = getConfigIntOption(configParser, section, "matchField3", 0)
    # Empty default: intersects() treats any non-empty value as a reject-file
    # path, so a "\t" default would make it try to create a file named TAB.
    rejectFileName = getConfigOption(configParser, section, "rejectFileName", "")
    trackGID = getConfigBoolOption(configParser, section, "trackGID", False)

    parser.set_defaults(delimiter=delimiter, infile3=infile3, matchField1=matchField1,
                        matchField2=matchField2, matchField3=matchField3,
                        rejectFileName=rejectFileName, trackGID=trackGID)

    return parser


def intersects(infile1Name, infile2Name, outfileName, delimiter="\t", infile3Name=None,
               matchField1=0, matchField2=0, matchField3=0, rejectFileName="", trackGID=False):
    """Intersect the candidate keys of two (or three) delimited files.

    Candidates are read from column matchField1 of infile1Name, matchField2 of
    infile2Name and, when infile3Name is not None, matchField3 of infile3Name.
    Candidates of file 1 that occur in every other input are written to
    outfileName, one per line, in file-1 order.

    With trackGID, the column following each match field is read as a GID and
    written after the candidate, separated by delimiter.  When several files
    supply a GID for the same candidate, the earliest file's value is kept.

    If rejectFileName is non-empty, file-1 candidates found in none of the
    other inputs are written there (same line format as the output file).

    Counts are printed to stdout: the input list sizes, the pairwise-only
    match counts (three-file mode only), and the number of full matches.
    """
    doFile3 = infile3Name is not None

    # candidate -> GID; only populated when trackGID is set.
    gidDict = {}

    def loadCandidates(fileName, matchField):
        # Read one input's candidates, merging its GIDs into gidDict.
        if not trackGID:
            return getCandidateListFromFile(fileName, delimiter, matchField)
        candidates, fileGIDDict = getCandidatesAndGIDFromFile(fileName, delimiter, matchField,
                                                               set(gidDict))
        for candidate, gid in fileGIDDict.items():
            # Keep the GID from the earliest file that supplied one.
            gidDict.setdefault(candidate, gid)
        return candidates

    def formatEntry(candidate):
        # One output line: the candidate, plus its GID when tracking GIDs.
        if trackGID:
            return "%s%s%s\n" % (candidate, delimiter, gidDict[candidate])
        return "%s\n" % candidate

    list1 = loadCandidates(infile1Name, matchField1)
    list2 = loadCandidates(infile2Name, matchField2)
    list3 = loadCandidates(infile3Name, matchField3) if doFile3 else []

    # Sets give O(1) membership tests instead of a list scan per candidate.
    set1, set2, set3 = set(list1), set(list2), set(list3)

    matchedList = []
    matchedList12 = []
    matchedList13 = []
    rejectedList = []
    for candidate in list1:
        in2 = candidate in set2
        in3 = candidate in set3
        if doFile3:
            if in2 and in3:
                matchedList.append(candidate)
            elif in3:
                matchedList13.append(candidate)
            elif in2:
                matchedList12.append(candidate)
            else:
                rejectedList.append(candidate)
        elif in2:
            matchedList.append(candidate)
        else:
            rejectedList.append(candidate)

    matchedList23 = []
    if doFile3:
        matchedList23 = [c for c in list2 if c not in set1 and c in set3]

    # A third list size is only reported when a third input was given.
    if doFile3:
        print("%d %d %d" % (len(list1), len(list2), len(list3)))
        print("%d %d %d" % (len(matchedList12), len(matchedList13), len(matchedList23)))
    else:
        print("%d %d" % (len(list1), len(list2)))
    print("%d" % len(matchedList))

    if rejectFileName:
        # The reject file is an output, so it must be opened for writing.
        with open(rejectFileName, "w") as rejectFile:
            rejectFile.writelines(formatEntry(c) for c in rejectedList)

    with open(outfileName, "w") as outfile:
        outfile.writelines(formatEntry(m) for m in matchedList)


def getCandidatesFromFile(filename, delimiter, matchField, trackGID=False, gidList=None):
    """Read the distinct candidate keys (and optionally their GIDs) from a file.

    Lines starting with "#" and blank lines are skipped.  Every other line is
    stripped of surrounding whitespace and split on delimiter; the field at
    index matchField is the candidate.

    filename   -- path of the delimited input file
    delimiter  -- field separator passed to str.split
    matchField -- 0-based index of the candidate column
    trackGID   -- when true, also record fields[matchField + 1] as the GID
    gidList    -- candidates whose GID is already known elsewhere; no GID is
                  recorded for them (None means none are known)

    Returns (candidateList, gidDict): the distinct candidates in first-seen
    order, and a dict mapping candidate -> GID (empty unless trackGID).  When
    a candidate appears on several lines, the GID from its last line is kept.

    Raises IndexError if a line has too few fields for matchField (or for
    matchField + 1 when trackGID is set).
    """
    knownGIDs = set(gidList) if gidList else set()
    candidateList = []
    seen = set()  # mirrors candidateList for O(1) duplicate checks
    gidDict = {}

    with open(filename) as infile:
        for line in infile:
            if line.startswith("#"):
                continue

            stripped = line.strip()
            if not stripped:
                # Blank line: would otherwise yield an empty-string candidate.
                continue

            fields = stripped.split(delimiter)
            candidate = fields[matchField]
            if candidate not in seen:
                seen.add(candidate)
                candidateList.append(candidate)

            if trackGID and candidate not in knownGIDs:
                gidDict[candidate] = fields[matchField + 1]

    return candidateList, gidDict


def getCandidatesAndGIDFromFile(filename, delimiter, matchField, gidList=None):
    """Read candidates and their GIDs from filename.

    Thin wrapper around getCandidatesFromFile with trackGID enabled.
    Candidates listed in gidList are not given a GID entry.

    Returns (candidateList, gidDict) as produced by getCandidatesFromFile.
    """
    if gidList is None:
        gidList = []
    # Forward the caller's gidList so already-known candidates keep their GID.
    return getCandidatesFromFile(filename, delimiter, matchField, trackGID=True, gidList=gidList)


def getCandidateListFromFile(filename, delimiter, matchField):
    """Return just the candidate list from getCandidatesFromFile, without GIDs."""
    return getCandidatesFromFile(filename, delimiter, matchField)[0]


# Run the command-line interface only when executed as a script.
if __name__ == "__main__":
    main(argv=sys.argv)