##################################
#                                #
# Last modified 02/21/2015       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import os
from sets import Set

def _config_rows(path):
    """Yield the tab-separated fields of every data line in ``path``.

    Lines that start with '#' and lines that are empty after stripping are
    skipped.  The file is closed as soon as the generator is exhausted.
    """
    with open(path) as handle:
        for line in handle:
            if line.startswith('#') or line.strip() == '':
                continue
            yield line.strip().split('\t')


def _read_wanted_sequences(fasta, wanted):
    """Fill ``wanted[ID]['seq']`` for every FASTA record whose ID is in ``wanted``.

    The record ID is the header text after '>' up to the first space or tab.
    Sequence lines are stripped and concatenated.
    """
    current = None  # ID of the wanted record being collected, if any
    chunks = []
    with open(fasta) as handle:
        for line in handle:
            if line.startswith('>'):
                if current is not None:
                    wanted[current]['seq'] = ''.join(chunks)
                ID = line.strip().split('>')[1].split(' ')[0].split('\t')[0]
                chunks = []
                current = ID if ID in wanted else None
            elif current is not None:
                chunks.append(line.strip())
    # The last record has no following header to trigger the store above,
    # so it must be flushed explicitly or it would be silently lost.
    if current is not None:
        wanted[current]['seq'] = ''.join(chunks)


def _rejected_by_x_filters(seq, NoFirstX, NoXX, NoX):
    """Return True if ``seq`` must be dropped under the active X filters."""
    if NoFirstX and seq.startswith('X'):
        return True
    if NoXX and 'XX' in seq:
        return True
    if NoX and 'X' in seq:
        return True
    return False


def run():
    """Split BLAST-matched query sequences into per-target-group FASTA files.

    Command line (read from ``sys.argv``):
        query_config target_config [-minAlignmentLength fieldID]
        [-catSequences fieldID] [-noX] [-noXX] [-noFirstX]

    For every group in the query config, each listed BLAST tabular file is
    read (column 1 = query name, column 2 = target name, column 4 read as
    the alignment length).  Every hit assigns its query to the target's
    group; when a query has several hits the last one read wins.  Sequences
    are then pulled from the matching query FASTA file and written to
    ``<group>.<target_group>.fa`` in the current directory, with headers
    ``>label_queryname`` and sequence lines of at most 150 characters.

    Exits with status 1 after printing usage when fewer than two positional
    arguments are given, or when a target name is assigned to two different
    groups.  A BLAST target name that is missing from the target config
    raises KeyError.
    """
    # Both config paths are mandatory, so at least three argv entries are
    # needed (program name + two paths).
    if len(sys.argv) < 3:
        print('usage: python %s query_config target_config [-minAlignmentLength fieldID] [-catSequences fieldID] [-noX] [-noXX] [-noFirstX]' % sys.argv[0])
        print('\tquery_config format:')
        print('\t\tblast_tabular_file\tquery.fa\tlabel\tgroup_label')
        print('\ttarget_config format:')
        print('\t\tgroup_label\tname')
        print('\tNote: the script will only consider names in fasta files up to the first space or a tab')
        print('\tNote: if you use the -minAlignmentLength, the fieldID is the column ID in the target_config file where the length is specified')
        print('\tNote: the -catSequences allows the addition of other sequences for each target group; the fieldID specifies in which column ID in the target_config file their location is listed')
        print('\tNote: the -noFirstX option will remove all protein sequences that start with an X')
        print('\tNote: the -noXX option will remove all protein sequences that have a run of more than one consecutive X in them')
        print('\tNote: the -noX option will remove all protein sequences that have an X residue in them; it overrides the previous two options')
        sys.exit(1)

    qconfig = sys.argv[1]
    tconfig = sys.argv[2]

    NoFirstX = '-noFirstX' in sys.argv
    NoXX = '-noXX' in sys.argv
    NoX = '-noX' in sys.argv

    doMinAlign = '-minAlignmentLength' in sys.argv
    MinAlignDict = {}
    if doMinAlign:
        MinAlignFieldID = int(sys.argv[sys.argv.index('-minAlignmentLength') + 1])

    doCatSeq = '-catSequences' in sys.argv
    CatSeqDict = {}
    if doCatSeq:
        CatSeqFieldID = int(sys.argv[sys.argv.index('-catSequences') + 1])

    # target name -> group label (plus the optional per-target / per-group
    # columns selected on the command line)
    TargetDict = {}
    for fields in _config_rows(tconfig):
        GroupLabel = fields[0]
        name = fields[1]
        if name in TargetDict and TargetDict[name] != GroupLabel:
            print('conflicting target group assignments detected, exiting')
            sys.exit(1)
        TargetDict[name] = GroupLabel
        if doMinAlign:
            MinAlignDict[name] = int(fields[MinAlignFieldID])
        if doCatSeq:
            CatSeqDict[GroupLabel] = fields[CatSeqFieldID]

    # query group -> {(label, blast file, fasta file): {}}
    QueryDict = {}
    for fields in _config_rows(qconfig):
        blasttab, query, label, group = fields[0], fields[1], fields[2], fields[3]
        QueryDict.setdefault(group, {})[(label, blasttab, query)] = {}

    for group in QueryDict:
        print(group)
        WantedDict = {}   # label -> query name -> {'group': ..., 'seq': ...}
        TGroups = set()   # every target group hit by this query group
        for (label, blasttab, query) in QueryDict[group]:
            WantedDict[label] = {}
            for fields in _config_rows(blasttab):
                qname = fields[0]
                tname = fields[1]
                if doMinAlign and int(fields[3]) < MinAlignDict[tname]:
                    continue
                WantedDict[label][qname] = {'group': TargetDict[tname]}
                TGroups.add(TargetDict[tname])
            _read_wanted_sequences(query, WantedDict[label])

        for TGroup in TGroups:
            outfilename = group + '.' + TGroup + '.fa'
            with open(outfilename, 'w') as outfile:
                for label in WantedDict:
                    for qname in WantedDict[label]:
                        entry = WantedDict[label][qname]
                        if entry['group'] != TGroup:
                            continue
                        try:
                            seq = entry['seq']
                        except KeyError:
                            # Query named in the BLAST table but absent
                            # from the FASTA file.
                            print('problem with: %s %s skipping' % (label, qname))
                            continue
                        if _rejected_by_x_filters(seq, NoFirstX, NoXX, NoX):
                            continue
                        outfile.write('>' + label + '_' + qname + '\n')
                        for i in range(0, len(seq), 150):
                            outfile.write(seq[i:i + 150] + '\n')
                if doCatSeq:
                    # Append the extra sequences configured for this group.
                    with open(CatSeqDict[TGroup]) as extra:
                        for line in extra:
                            outfile.write(line)

# Script entry point: this call is unconditional, so importing the module
# also executes run() against the current sys.argv.
run()

