##################################
#                                #
# Last modified 2016/12/31       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
import copy
from sets import Set

def _read_proteins(fasta):
    """Read a FASTA file into a dict mapping protein ID -> sequence string.

    The protein ID is taken from the header line as the second '|'-separated
    field, truncated at the first space. Sequence lines belonging to one
    record are stripped and concatenated.

    Lines that appear before the first header are ignored, and an empty file
    yields an empty dict. A header without a '|' raises IndexError.
    """
    proteins = {}
    protein = None
    chunks = []
    with open(fasta) as handle:
        for line in handle:
            if line[0] == '>':
                # A new record starts: flush the one collected so far.
                if protein is not None:
                    proteins[protein] = ''.join(chunks)
                protein = line.strip().split('>')[1].split('|')[1].split(' ')[0]
                chunks = []
            elif protein is not None:
                chunks.append(line.strip())
    if protein is not None:
        proteins[protein] = ''.join(chunks)
    return proteins


def _read_cogs(cogs):
    """Read the COG assignment CSV.

    Rows with fewer than 8 comma-separated fields are skipped. From each
    remaining row, field 0 is used as the protein ID, field 1 as the species,
    field 6 as the COG ID and field 7 as the COG match type.

    Returns a tuple (cog_dict, species_list) where cog_dict maps
    COG -> species -> list of (proteinID, COGmatch) in file order, and
    species_list is the sorted list of all species seen.
    """
    cog_dict = {}
    species_seen = set()
    with open(cogs) as handle:
        for line in handle:
            fields = line.strip().split(',')
            if len(fields) < 8:
                continue
            protein_id = fields[0]
            species = fields[1]
            cog = fields[6]
            cog_match = fields[7]
            species_seen.add(species)
            per_species = cog_dict.setdefault(cog, {})
            per_species.setdefault(species, []).append((protein_id, cog_match))
    return cog_dict, sorted(species_seen)


def run():
    """Write one FASTA file per COG that is single-copy in enough species.

    Command line (read from sys.argv):
        cog.csv cog-prot.fa outfile_prefix [-useALL 0,1,2,3] [-speciesNumber N]

    A species counts towards a COG when it has exactly one entry for that COG
    and the entry's match type is accepted: only '0' by default, or the
    comma-separated list given to -useALL ('0' is always added to it).
    A COG is written to '<outfile_prefix>.<COG>.fa' when the number of such
    species is at least N (-speciesNumber), or equals the total number of
    species in the CSV when the option is absent.

    Exits with status 1 after printing usage if fewer than three positional
    arguments are given. Raises KeyError if a selected protein ID is missing
    from the FASTA file.
    """
    # Three positional arguments are required, i.e. argv must hold at least
    # four items; the output prefix is read from sys.argv[3] below.
    if len(sys.argv) < 4:
        print('usage: python %s cog.csv cog-prot.fa outfile_prefix [-useALL 0,1,2,3] [-speciesNumber N]' % sys.argv[0])
        print('\tNote: the script will output only a single COG match per species, and it will be the best hit')
        print('\t\tuse the [-useALL] option if you want all hits')
        sys.exit(1)

    cogs = sys.argv[1]
    fasta = sys.argv[2]
    outprefix = sys.argv[3]

    FSN = None
    if '-speciesNumber' in sys.argv:
        FSN = int(sys.argv[sys.argv.index('-speciesNumber') + 1])

    # Without -useALL only match type '0' is accepted; with it, '0' is always
    # part of the accepted list as well.
    COGMatchTypes = ['0']
    if '-useALL' in sys.argv:
        COGMatchTypes = sys.argv[sys.argv.index('-useALL') + 1].split(',')
        if '0' not in COGMatchTypes:
            COGMatchTypes.append('0')

    ProteinDict = _read_proteins(fasta)
    print('finished inputting protein sequences')

    COGDict, speciesList = _read_cogs(cogs)
    print('finished inputting COGs')

    if FSN is None:
        FSN = len(speciesList)
    print('will require single-copy ortholog state in %s species' % FSN)

    for COG in sorted(COGDict):
        hits = COGDict[COG]

        SingleCopyIn = 0
        for species in speciesList:
            entries = hits.get(species)
            if entries is None or len(entries) != 1:
                continue
            if entries[0][1] in COGMatchTypes:
                SingleCopyIn += 1

        if SingleCopyIn < FSN:
            continue

        print(COG)
        with open(outprefix + '.' + COG + '.fa', 'w') as outfile:
            # NOTE(review): every species that has this COG is written, using
            # its first listed entry, including species that were not counted
            # as single-copy above (possible when -speciesNumber is below the
            # total). Whether the first entry is the "best hit" depends on the
            # ordering of the input CSV -- confirm against the data source.
            for species in speciesList:
                if species in hits:
                    proteinID = hits[species][0][0]
                    outfile.write('>' + species + '\n')
                    outfile.write(ProteinDict[proteinID] + '\n')

# Script entry point. NOTE: there is no `if __name__ == '__main__'` guard, so
# this call also executes (and reads sys.argv) whenever the module is imported.
run()
