##################################
#                                #
# Last modified 12/03/2013       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math

def _load_wanted_go_ids(path):
    """Return the set of GO IDs taken from the first tab-separated column of path.

    Lines starting with '#' and blank lines are ignored.
    """
    wanted = set()
    with open(path) as infile:
        for line in infile:
            if line.startswith('#') or not line.strip():
                continue
            wanted.add(line.strip().split('\t')[0])
    return wanted


def _load_go_annotations(path, wanted):
    """Read a funcassociate-style association file, keeping only wanted GO IDs.

    Data lines are read as: GO ID <tab> GO name <tab> space-separated genes.
    Returns a dict mapping (GO ID, file-name-safe GO name) to
    {upper-cased gene name: False}. The False flags are set to True later, for
    genes that are seen in the input file. A later line with the same
    (ID, name) replaces an earlier one. Lines with fewer than two columns are
    skipped. A missing gene column gives an empty gene set.
    """
    go_genes = {}
    with open(path) as infile:
        for line in infile:
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            if len(fields) < 2 or fields[0] not in wanted:
                continue
            go_id = fields[0]
            # The name becomes part of an output file name. Avoid '/' (path
            # separator) and spaces.
            go_name = fields[1].replace(' ', '_').replace('/', '-')
            genes = fields[2].split() if len(fields) > 2 else []
            # Upper-case here as well as on the input side, so the comparison
            # is genuinely case-insensitive.
            go_genes[(go_id, go_name)] = dict.fromkeys(
                [gene.upper() for gene in genes], False)
    return go_genes


def _write_category_files(input_path, gene_col, outprefix, go_id, go_name, genes):
    """Write the .in and .out files for one GO category.

    `genes` maps upper-cased gene name -> seen flag and is updated in place.
    """
    stem = '%s.%s.%s' % (outprefix, go_id.replace(':', ''), go_name)
    with open(stem + '.in', 'w') as out_in:
        with open(input_path) as infile:
            for line in infile:
                if line.startswith('#'):
                    # Header/comment lines are copied into every .in file.
                    out_in.write(line)
                    continue
                # Drop only the line terminator. A full strip() would also
                # remove a leading empty column and shift gene_col.
                fields = line.rstrip('\r\n').split('\t')
                try:
                    gene = fields[gene_col].strip().upper()
                except IndexError:
                    continue  # blank/short line without the gene column
                if gene in genes:
                    out_in.write(line)
                    genes[gene] = True
    with open(stem + '.out', 'w') as out_out:
        # Sorted so the output is deterministic.
        for gene in sorted(genes):
            if not genes[gene]:
                out_out.write(gene + '\n')


def run():
    """Split a tab-delimited file by membership in selected GO categories.

    Command line (all five arguments are required):
        input geneFieldID list_of_GO_categories
        funcassociate_go_associations.txt outprefix

    geneFieldID is the 0-based index of the tab-separated column in `input`
    that holds the gene name. Each GO category is processed if it is listed
    in list_of_GO_categories and present in the association file. For each
    one, its ID and name are printed and two files are written:
        <outprefix>.<GO ID without ':'>.<GO name>.in
            the input's '#' lines plus every input line whose gene is
            annotated to the category
        <outprefix>.<GO ID without ':'>.<GO name>.out
            the category's annotated genes never seen in the input, one per
            line, sorted
    Gene names are compared after upper-casing on both sides.
    Exits with status 1 and prints usage if arguments are missing.
    """
    # argv[0] is the script, so five arguments means len(sys.argv) == 6.
    if len(sys.argv) < 6:
        print('usage: python %s input geneFieldID list_of_GO_categories funcassociate_go_associations.txt outprefix' % sys.argv[0])
        print('\tDo not worry about gene name capitalization, the script will only compare names converted to all caps')
        print('\tlist_of_GO_categories format: one entry per line listing the ID like this: GO:0019815')
        print('\tfuncassociate_go_associations.txt format: GO:0000009\talpha-1,6-mannosyltransferase activity\tALG12 ALG2')
        sys.exit(1)

    input_path = sys.argv[1]
    gene_col = int(sys.argv[2])
    wanted = _load_wanted_go_ids(sys.argv[3])
    go_genes = _load_go_annotations(sys.argv[4], wanted)
    outprefix = sys.argv[5]

    for go_id, go_name in sorted(go_genes):
        print('%s %s' % (go_id, go_name))
        _write_category_files(input_path, gene_col, outprefix, go_id, go_name,
                              go_genes[(go_id, go_name)])

# Run only when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    run()
