##################################
#                                #
# Last modified 01/15/2015       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set

# Rules for trimming a FASTA header down to the name shared by all pieces of
# one taxonomic entity.  Each entry is (marker, skip_if_virus_follows): the
# header is cut at the first occurrence of `marker`; when the flag is set the
# rule is skipped if the text after the last occurrence of the marker contains
# 'virus' (so names such as "... RNA virus 1" are left intact).  Rules are
# tried in order and the first applicable one wins.
_SUFFIX_RULES = (
    (' chromosome segment', False),
    (' segment', False),
    (' dsRNA', True),
    (' RNA', True),
    (' DNA', True),
    (' ORF', False),
)

# Characters that are rewritten (in this order) when a common name is turned
# into an output file name.
_FILENAME_REPLACEMENTS = (
    ('/', '-'),
    (';', '-'),
    ("'", ''),
    ('(', ''),
    (')', ''),
    ('[', ''),
    (']', ''),
    (':', '-'),
    ('|', '__'),
)

# Number of sequence characters written per output line.
_LINE_WIDTH = 50


def _read_fasta(path):
    """Parse the FASTA file at `path` into a {header: sequence} dict.

    Blank lines are skipped.  A header is the text between the leading '>'
    and the next '>' (if any) on the stripped line.  Sequence lines that
    appear before the first header are ignored, and a file without any
    header yields an empty dict.  If a header occurs more than once, the
    last record wins.
    """
    records = {}
    header = None
    chunks = []
    with open(path) as handle:
        for line in handle:
            stripped = line.strip()
            if stripped == '':
                continue
            if line[0] == '>':
                if header is not None:
                    records[header] = ''.join(chunks)
                # The line starts with '>', so split() always yields at
                # least two items; anything after a second '>' is dropped.
                header = stripped.split('>')[1]
                chunks = []
            elif header is not None:
                chunks.append(stripped)
    if header is not None:
        records[header] = ''.join(chunks)
    return records


def _common_id(header):
    """Return the grouping name (also used as file stem) for `header`.

    The header is cut at the first applicable marker from _SUFFIX_RULES, or
    at ', complete' when none applies; one trailing comma is removed; if the
    result contains '| ', only the text between the first and second '| ' is
    kept; finally the _FILENAME_REPLACEMENTS substitutions are applied.
    """
    for marker, skip_if_virus_follows in _SUFFIX_RULES:
        if marker not in header:
            continue
        if skip_if_virus_follows and 'virus' in header.split(marker)[-1]:
            continue
        name = header.split(marker)[0]
        break
    else:
        name = header.split(', complete')[0]

    if name.endswith(','):
        name = name[:-1]

    # Headers without a '| ' separator are used whole instead of raising
    # IndexError.
    pieces = name.split('| ')
    if len(pieces) > 1:
        name = pieces[1]

    for old, new in _FILENAME_REPLACEMENTS:
        name = name.replace(old, new)
    return name


def _write_fasta(path, records):
    """Write the {header: sequence} dict `records` to `path` as FASTA.

    Each sequence is wrapped at _LINE_WIDTH characters per line.
    """
    with open(path, 'w') as outfile:
        for header, sequence in records.items():
            outfile.write('>' + header + '\n')
            for start in range(0, len(sequence), _LINE_WIDTH):
                outfile.write(sequence[start:start + _LINE_WIDTH] + '\n')


def run():
    """Split a multi-record FASTA file into one file per common entity.

    Reads the input path from sys.argv[1] and the output folder from
    sys.argv[2].  Records whose headers reduce to the same common name (see
    _common_id) are written together to '<outfolder>/<name>.fasta', with
    spaces in the name replaced by underscores.  The output folder must
    already exist.  Prints the number of input records and the number of
    output groups.

    Exits with status 1 after printing a usage message when fewer than two
    command-line arguments are given.
    """
    # Both the input file and the output folder are required.
    if len(sys.argv) < 3:
        print('usage: python %s fasta outputfolder' % sys.argv[0])
        print('\tNote: The script will look for the last comma and if there is a space-separated number before it,')
        print('\tit will assume that the prefix prior to that indicates the common taxonomic entity the multiple')
        print('\tentries for which are to be combined. Make sure to manually check the output for inconsistencies in parsing')
        sys.exit(1)

    fasta = sys.argv[1]
    outfolder = sys.argv[2]

    genome = _read_fasta(fasta)

    grouped = {}
    for header, sequence in genome.items():
        grouped.setdefault(_common_id(header), {})[header] = sequence

    print('%d %d' % (len(genome), len(grouped)))

    for name, records in grouped.items():
        # Only the file name is sanitised; the folder path is used verbatim.
        filename = name.replace(' ', '_') + '.fasta'
        _write_fasta(outfolder + '/' + filename, records)

# Script entry point. The call is unconditional, so it also executes when
# this file is imported as a module.
run()

