##################################
#                                #
# Last modified 06/15/2015       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set
from scipy.stats import hypergeom

def _tab_fields(line):
    """Split one tab-separated line into columns, keeping empty columns in place."""
    # Only the line terminator is removed; strip() would also eat leading or
    # trailing tabs and shift/shorten the column list.
    return line.rstrip('\r\n').split('\t')


def _read_gene_list(path, fieldID):
    """Return the set of gene IDs found in column ``fieldID`` of a tab-separated file.

    Lines starting with '#' and blank lines are skipped.
    """
    genes = set()
    with open(path) as infile:
        for line in infile:
            if line.startswith('#') or not line.strip():
                continue
            genes.add(_tab_fields(line)[fieldID].strip())
    return genes


def _read_pfam_to_go(path):
    """Map each PFAM domain name to a list of (GO ID, GO description) tuples.

    Lines starting with '!' and blank lines are skipped.  Each remaining line
    is expected to look like ``<acc> <domain> > <GO description> ; <GO ID>``
    (presumably the pfam2go mapping format -- confirm against the input used).
    """
    domainToGO = {}
    with open(path) as infile:
        for line in infile:
            if line.startswith('!') or not line.strip():
                continue
            parts = line.strip().split(' > ')
            domain = parts[0].split(' ')[1]
            goFields = parts[1].split(' ; ')
            domainToGO.setdefault(domain, []).append((goFields[1], goFields[0]))
    return domainToGO


def _collect_category_genes(path, domainToGO, listGenes, maxEvalue, splitGeneID):
    """Assign genes to GO categories through their PFAM domain hits.

    Reads the tab-separated hit table at ``path`` (columns used: 0 = domain
    name, 2 = gene ID, 4 = E-value).  Lines starting with '#' and blank lines
    are skipped.  Hits with E-value above ``maxEvalue`` (when not None) and
    hits to domains absent from ``domainToGO`` are ignored.  When
    ``splitGeneID`` is true the gene ID is truncated at its first '|'.

    Returns ``(categoryGenes, background, listWithDomains)``:
    a dict mapping (GO ID, GO description) to the set of genes in it, the
    set of all genes with at least one retained hit, and the subset of that
    set that also appears in ``listGenes``.
    """
    categoryGenes = {}
    background = set()
    listWithDomains = set()
    with open(path) as infile:
        for line in infile:
            if line.startswith('#') or not line.strip():
                continue
            fields = _tab_fields(line)
            domain = fields[0].strip()
            gene = fields[2].strip()
            if splitGeneID:
                gene = gene.split('|')[0]
            if maxEvalue is not None and float(fields[4]) > maxEvalue:
                continue
            if domain not in domainToGO:
                continue
            background.add(gene)
            if gene in listGenes:
                listWithDomains.add(gene)
            for category in domainToGO[domain]:
                categoryGenes.setdefault(category, set()).add(gene)
    return categoryGenes, background, listWithDomains


def _enrichment_pvalue(totalGenes, inCategory, listSize, observed):
    """One-sided hypergeometric P(X >= observed) for an over-representation test.

    ``totalGenes`` is the population size, ``inCategory`` the number of
    population members in the category and ``listSize`` the number drawn.
    """
    # sf(k) is P(X > k), so sf(observed - 1) includes the observed count itself;
    # it is also more accurate than 1 - cdf(...) for very small p-values.
    return hypergeom(totalGenes, inCategory, listSize).sf(observed - 1)


def run():
    """Run a PFAM-domain-based GO enrichment test from the command line.

    Arguments are read from ``sys.argv``::

        whole_proteome_PFAM_tblout-tab PFAMtoGO gene_list fieldID outfile
            [-MaxEvalue value] [-splitPFAMGeneID]

    Every gene with at least one retained hit to a domain that has a GO
    mapping forms the background; each GO category contains the genes
    carrying a domain mapped to it.  For every category the one-sided
    hypergeometric P(X >= observed) of finding at least that many list genes
    is computed and Bonferroni-adjusted by the number of categories (capped
    at 1).  Categories with adjusted P <= 0.05 are written to ``outfile`` as
    tab-separated rows ``N X P P_adj GO_ID GO_description``, where N is the
    number of list genes in the category and X is the category size.
    Summary counts and per-category statistics are printed to stdout.

    Exits with status 1 and a usage message when required arguments are
    missing or ``-MaxEvalue`` has no value.
    """
    usage = 'usage: python %s whole_proteome_PFAM_tblout-tab PFAMtoGO gene_list fieldID outfile [-MaxEvalue value] [-splitPFAMGeneID]' % sys.argv[0]

    # Five positional arguments are required (sys.argv[1] .. sys.argv[5]).
    if len(sys.argv) < 6:
        print(usage)
        sys.exit(1)

    PFAM = sys.argv[1]
    PFAMtoGO = sys.argv[2]
    genelist = sys.argv[3]
    fieldID = int(sys.argv[4])
    outfilename = sys.argv[5]

    MaxEvalue = None
    if '-MaxEvalue' in sys.argv:
        valueIndex = sys.argv.index('-MaxEvalue') + 1
        if valueIndex >= len(sys.argv):
            print(usage)
            sys.exit(1)
        MaxEvalue = float(sys.argv[valueIndex])

    doSPFAMGID = '-splitPFAMGeneID' in sys.argv

    geneInList = _read_gene_list(genelist, fieldID)
    print('genes in list: %d' % len(geneInList))

    PFAMtoGODict = _read_pfam_to_go(PFAMtoGO)
    print('PFAMtoGO domains: %d' % len(PFAMtoGODict))

    GOCatDict, TotalGenesConsidered, geneInListWithDomains = _collect_category_genes(
        PFAM, PFAMtoGODict, geneInList, MaxEvalue, doSPFAMGID)

    TotalCats = len(GOCatDict)
    print('Total GO categories: %d' % TotalCats)

    TotalGenes = len(TotalGenesConsidered)
    print('Total genes considered: %d' % TotalGenes)

    NumGeneInListWithDomains = len(geneInListWithDomains)
    print('genes in list with domains: %d' % NumGeneInListWithDomains)

    with open(outfilename, 'w') as outfile:
        outfile.write('#N\tX\tP\tP_adj\tattrib ID\tattrib name\n')
        # Sorted so the report order is reproducible across runs.
        for (GOID, GODes) in sorted(GOCatDict):
            categoryGenes = GOCatDict[(GOID, GODes)]
            InCat = len(categoryGenes)
            InList = len(categoryGenes & geneInList)
            P = _enrichment_pvalue(TotalGenes, InCat, NumGeneInListWithDomains, InList)
            # Bonferroni correction; a probability cannot exceed 1.
            P_adj = min(1.0, P * TotalCats)
            print(' '.join(str(value) for value in (
                TotalCats, TotalGenes, InCat, NumGeneInListWithDomains,
                InList, GOID, GODes, P, P_adj)))
            if P_adj <= 0.05:
                outfile.write('\t'.join([str(InList), str(InCat), str(P),
                                         str(P_adj), GOID, GODes]) + '\n')

# Only run the command-line entry point when executed as a script, so that
# importing this module (e.g. for testing) does not parse sys.argv or exit.
if __name__ == '__main__':
    run()

