##################################
#                                #
# Last modified 2019/12/01       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import random

def _header_id(line, split_space):
    """Return the sequence ID of a FASTA header line (one starting with '>').

    The leading '>' and trailing whitespace/newline are removed. Everything
    after the '>' is kept, including any further '>' characters. With
    *split_space*, only the text before the first space character is kept.
    """
    seq_id = line.rstrip()[1:]
    if split_space:
        seq_id = seq_id.split(' ')[0]
    return seq_id


def _fasta_ids(path, split_space):
    """Return the header IDs of FASTA file *path* as a list, in file order."""
    with open(path) as handle:
        return [_header_id(line, split_space)
                for line in handle if line.startswith('>')]


def _wanted_ids(path, field_id):
    """Return the set of values in tab-separated column *field_id* of *path*.

    Blank lines are skipped. Only the line terminator is removed before
    splitting, so an empty leading column does not shift the column indices.
    The selected field is stripped of surrounding whitespace.

    Raises IndexError if a non-blank line has no column *field_id*.
    """
    wanted_ids = set()
    with open(path) as handle:
        for line in handle:
            if not line.strip():
                continue
            fields = line.rstrip('\r\n').split('\t')
            wanted_ids.add(fields[field_id].strip())
    return wanted_ids


def _filter_fasta(fasta, out_filename, wanted_ids, split_space, prefix):
    """Copy the records of *fasta* whose header ID is in *wanted_ids*.

    Kept header lines are written with *prefix* inserted right after the
    '>'. Their sequence lines are copied unchanged. Lines before the first
    header are dropped.
    """
    keep = False
    with open(fasta) as source, open(out_filename, 'w') as outfile:
        for line in source:
            if line.startswith('>'):
                # Each header decides whether the lines up to the next
                # header are written.
                keep = _header_id(line, split_space) in wanted_ids
                if keep:
                    outfile.write('>' + prefix + line[1:])
            elif keep:
                outfile.write(line)


def run():
    """Command-line entry point: extract selected records from a FASTA file.

    Usage (from sys.argv):
        wanted fieldID fasta outputfilename
            [-random fraction] [-prefix string] [-splitspace]

    Without -random, the records kept are those whose header ID appears in
    tab-separated column ``fieldID`` (0-based) of the ``wanted`` file. With
    -random, ``wanted`` is ignored and ``int(fraction * n)`` of the n headers
    are sampled at random. -prefix inserts a string after '>' in every
    written header. -splitspace truncates header IDs at the first space
    before matching.

    Prints usage and exits with status 1 when fewer than four positional
    arguments are given or the -random fraction is outside [0, 1].
    """
    # Four positional arguments plus the program name are required.
    if len(sys.argv) < 5:
        print('usage: python %s wanted fieldID fasta outputfilename [-random fraction] [-prefix string] [-splitspace]' % sys.argv[0])
        print('\tNote: the [-random] option will disregard the wanted file and will just sample the set of sequences randomly')
        sys.exit(1)

    wanted = sys.argv[1]
    field_id = int(sys.argv[2])
    fasta = sys.argv[3]
    out_filename = sys.argv[4]

    prefix = ''
    if '-prefix' in sys.argv:
        prefix = sys.argv[sys.argv.index('-prefix') + 1]

    split_space = '-splitspace' in sys.argv

    if '-random' in sys.argv:
        fraction = float(sys.argv[sys.argv.index('-random') + 1])
        # random.sample raises an opaque ValueError for sizes outside
        # [0, population], so reject bad fractions up front.
        if not 0.0 <= fraction <= 1.0:
            print('error: -random fraction must be between 0 and 1, got %s' % fraction)
            sys.exit(1)
        seq_ids = _fasta_ids(fasta, split_space)
        wanted_ids = set(random.sample(seq_ids, int(fraction * len(seq_ids))))
    else:
        wanted_ids = _wanted_ids(wanted, field_id)

    _filter_fasta(fasta, out_filename, wanted_ids, split_space, prefix)
   
# Run only when executed as a script, so importing this module (e.g. to
# reuse its helpers or test them) does not parse sys.argv or call sys.exit.
if __name__ == '__main__':
    run()
