##################################
#                                #
# Last modified 04/11/2016       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import os
from sets import Set

def _read_orthogroups(path):
    """Parse the orthologous-groups file into a nested lookup table.

    Lines starting with '#' are skipped.  Every other line is split on tabs;
    the first field is ignored and each remaining field is read as
    ``species:gene``.

    Returns a dict such that ``table[species][gene][other_species]`` is the
    gene of ``other_species`` listed on the same line.  When a line lists
    several genes for ``other_species`` the last one wins.

    Raises IndexError if a member field contains no ':'.
    """
    orthologs = {}
    with open(path) as handle:
        for line in handle:
            if line.startswith('#'):
                continue
            members = []
            for field in line.strip().split('\t')[1:]:
                parts = field.split(':')
                members.append((parts[1], parts[0]))
            for (gene, species) in members:
                partners = {}
                for (gene2, species2) in members:
                    if (gene2, species2) != (gene, species):
                        partners[species2] = gene2
                # A gene seen again on a later line replaces its earlier entry.
                orthologs.setdefault(species, {})[gene] = partners
    return orthologs


def _read_hog_fastas(list_path):
    """Read every FASTA file named in ``list_path``.

    Each non-comment, non-blank line of the list holds a FASTA path in its
    first tab-separated column.  The HOG name is the path's last
    '/'-separated component truncated at the first '.fa'.  FASTA header lines
    are expected to look like ``>gene [species]``.

    Returns ``(hogs, max_copies)`` where ``hogs[name][species]`` is the list
    of genes found in that file and ``max_copies[species]`` is the largest
    number of genes that species has in any single file.

    Prints a message and exits with status 1 if two files map to the same
    HOG name.
    """
    hogs = {}
    max_copies = {}
    with open(list_path) as listing:
        for fileline in listing:
            # Test the list-file line itself, not a leftover FASTA/OG line.
            if fileline.startswith('#'):
                continue
            fasta = fileline.strip().split('\t')[0]
            if not fasta:
                continue  # blank line: nothing to open
            # Derive the name from the stripped path so that a trailing
            # newline can never leak into the output table.
            name = fasta.split('/')[-1].split('.fa')[0]
            if name in hogs:
                print('duplicate file names detected, exiting')
                print(name)
                sys.exit(1)
            by_species = hogs[name] = {}
            with open(fasta) as records:
                for line in records:
                    if not line.startswith('>'):
                        continue
                    header = line.strip()
                    species = header.split(']')[0].split('[')[1]
                    gene = header.split('>')[1].split('[')[0].strip()
                    genes = by_species.setdefault(species, [])
                    genes.append(gene)
                    max_copies[species] = max(max_copies.get(species, 1),
                                              len(genes))
    return hogs, max_copies


def _write_table(outfilename, hogs, max_copies, orthologs, max_size):
    """Write the tab-separated HOG table.

    Each species gets ``width`` columns, where ``width`` is ``max_size`` when
    given and otherwise the largest per-HOG gene count of any species.  For
    every HOG the species with the most genes (first in sorted order on ties)
    is the reference: its sorted genes fill its own columns and, for every
    other species, the orthologue recorded for that gene in ``orthologs`` (or
    '-' when there is none).  Unused cells are padded with '-'.  HOGs whose
    reference species has more than ``max_size`` genes are omitted.
    """
    species_list = sorted(max_copies)
    if max_size is not None:
        width = max_size
    else:
        width = max([0] + list(max_copies.values()))

    with open(outfilename, 'w') as outfile:
        header = ['#HOG']
        for species in species_list:
            header.extend([species] * width)
        outfile.write('\t'.join(header) + '\n')

        for name in sorted(hogs):
            by_species = hogs[name]

            ref_species = None
            ref_count = 0
            for species in species_list:
                count = len(by_species.get(species, ()))
                if count > ref_count:
                    ref_count = count
                    ref_species = species
            if max_size is not None and ref_count > max_size:
                continue

            columns = dict((species, []) for species in species_list)
            # ref_species is None when the FASTA had no header lines; such a
            # HOG is written as a row of '-' instead of raising KeyError.
            if ref_species is not None:
                # A species absent from the orthogroup file has no partners.
                ref_orthologs = orthologs.get(ref_species, {})
                for gene in sorted(by_species[ref_species]):
                    columns[ref_species].append(gene)
                    partners = ref_orthologs.get(gene, {})
                    for species in species_list:
                        if species != ref_species:
                            columns[species].append(
                                partners.get(species, '-'))

            row = [name]
            for species in species_list:
                cells = columns[species][:width]
                row.extend(cells)
                row.extend(['-'] * (width - len(cells)))
            outfile.write('\t'.join(row) + '\n')


def run():
    """Build a per-species table of HOG members from the command line.

    Expects ``sys.argv`` to hold: the path of a file listing HOG FASTA files,
    the path of the orthologous-groups file, the output file name, and
    optionally ``-maxSize M``.  Prints a usage message and exits with status
    1 when a required argument or the value of ``-maxSize`` is missing.

    ``print`` is always called with a single parenthesised argument so the
    function behaves the same under Python 2 and Python 3.
    """
    usage = 'usage: python %s list_of_HOG_fasta_files OrthologousGroups.txt outfilename [-maxSize M]' % sys.argv[0]

    # Three positional arguments are read below, so argv needs >= 4 entries.
    if len(sys.argv) < 4:
        print(usage)
        sys.exit(1)

    FFlist = sys.argv[1]
    OG = sys.argv[2]
    outfilename = sys.argv[3]

    max_size = None
    if '-maxSize' in sys.argv:
        value_index = sys.argv.index('-maxSize') + 1
        if value_index >= len(sys.argv):
            print(usage)
            sys.exit(1)
        max_size = int(sys.argv[value_index])

    orthologs = _read_orthogroups(OG)
    print('finished inputting OG')

    hogs, max_copies = _read_hog_fastas(FFlist)
    print('finished inputting HOG')

    _write_table(outfilename, hogs, max_copies, orthologs, max_size)

# Script entry point: run() is invoked unconditionally at module level, so it
# executes whenever this file is run (and also if it is ever imported).
run()
