##################################
#                                #
# Last modified 2017/12/14       #
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import Levenshtein
import sys
import string
import math
from sets import Set
import time
import itertools

def getReverseComplement(preliminarysequence):
    """Return the reverse complement of a nucleotide sequence.

    The input is read from its last character to its first and every
    character is translated through a fixed lookup table.  A/T and G/C are
    swapped (case is preserved); N, X, R, M, Y, S, K and W (either case) are
    mapped to themselves.

    Raises KeyError if the sequence contains a character that is not in the
    lookup table.
    """
    # NOTE(review): R/Y and M/K are complementary pairs under the IUPAC
    # ambiguity codes, but this table maps each of them to itself -- confirm
    # whether that is intended before relying on it for ambiguous bases.
    DNA = {'A':'T','T':'A','G':'C','C':'G','N':'N','X':'X',
           'a':'t','t':'a','g':'c','c':'g','n':'n','x':'x',
           'R':'R','r':'r','M':'M','m':'m','Y':'Y','y':'y',
           'S':'S','s':'s','K':'K','k':'k','W':'W','w':'w'}
    # A single join avoids rebuilding the result string once per base.
    return ''.join([DNA[base] for base in reversed(preliminarysequence)])

def _replacements(length):
    """Return every string of `length` bases over A, G, C, T (in that order)."""
    # itertools.product enumerates ordered strings; the previously used
    # combinations_with_replacement only yields order-insensitive tuples and
    # therefore never produced e.g. 'GA' or 'TC'.
    return [''.join(bases) for bases in itertools.product('AGCT', repeat=length)]


def _tiled_mismatches(sequence, maxMM):
    """Yield (newseq, pos, mismatches, deletion, insertion) for tiled mismatch windows."""
    for M in range(1, maxMM + 1):
        candidates = _replacements(M)
        # NOTE(review): this bound never lets a window cover the last base of
        # the sequence; kept as in the original -- confirm it is intended.
        for i in range(len(sequence) - M):
            prefix = sequence[:i]
            window = sequence[i:i + M]
            suffix = sequence[i + M:]
            for candidate in candidates:
                # Keep only replacements whose edit distance to the window is
                # the full window length.
                if Levenshtein.distance(window, candidate) < M:
                    continue
                yield (prefix + candidate + suffix, str(i + 1), M, 0, 0)


def _pairwise_mismatches(sequence, maxMM):
    """Yield variants carrying two separated mismatch windows of equal size."""
    for M in range(1, maxMM // 2 + 1):
        candidates = _replacements(M)
        for i in range(len(sequence) - 2 * M):
            prefix = sequence[:i]
            window1 = sequence[i:i + M]
            for candidate1 in candidates:
                if Levenshtein.distance(window1, candidate1) < M:
                    continue
                # The second window starts at least one base after the first.
                for j in range(i + M + 1, len(sequence) - M):
                    middle = sequence[i + M:j]
                    window2 = sequence[j:j + M]
                    suffix = sequence[j + M:]
                    for candidate2 in candidates:
                        if Levenshtein.distance(window2, candidate2) < M:
                            continue
                        newseq = prefix + candidate1 + middle + candidate2 + suffix
                        # The mismatch column reports the size of each window (M).
                        yield (newseq, '%d,%d' % (i + 1, j + 1), M, 0, 0)


def _deletions(sequence, maxDel):
    """Yield variants with one contiguous deletion of 1..maxDel bases."""
    for Del in range(1, maxDel + 1):
        # NOTE(review): as with mismatches, the last base is never deleted.
        for i in range(len(sequence) - Del):
            yield (sequence[:i] + sequence[i + Del:], str(i + 1), 0, Del, 0)


def _insertions(sequence, maxIns):
    """Yield variants with one contiguous insertion of 1..maxIns bases."""
    for Ins in range(1, maxIns + 1):
        # The candidate list only depends on Ins, so build it once per size.
        candidates = _replacements(Ins)
        # NOTE(review): the number of insertion positions shrinks with Ins,
        # as in the original loop bound -- confirm it is intended.
        for i in range(len(sequence) - Ins):
            prefix = sequence[:i]
            suffix = sequence[i:]
            for candidate in candidates:
                yield (prefix + candidate + suffix, str(i + 1), 0, 0, Ins)


def run():
    """Command-line entry point: write mismatch/indel variants of wanted guides.

    Positional arguments are read from sys.argv:
        1  input          tab-delimited file with the guides
        2  IDFieldID      0-based column of the guide ID in the input
        3  seqFieldID     0-based column of the guide sequence in the input
        4  wanted_guides  tab-delimited file listing the IDs to process
        5  fieldID        0-based column of the ID in the wanted file
        6  numberMisMatches   largest mismatch window size
        7  deletionSize       largest deletion length
        8  InsertionSize      largest insertion length
        9  outfile        path of the output file
    The optional '-combinatorial' flag switches the mismatch step to pairs of
    separated windows, each of size 1..numberMisMatches//2.

    Lines starting with '#' are skipped in the wanted file; in the input they
    are copied to the output with five extra column names appended.  For every
    input line whose ID is wanted, one output line per variant is written:
    the stripped input line followed by the variant sequence, the 1-based
    position(s), the mismatch size, the deletion size and the insertion size.

    Prints the usage text and exits with status 1 when arguments are missing.
    """
    # sys.argv[9] (the output file) is required, so ten entries are needed;
    # the old "< 9" check let a missing outfile through to an IndexError.
    if len(sys.argv) < 10:
        print('usage: python %s input IDFiedlID seqFiedlID wanted_guides fieldID numberMisMatches deletionSize InsertionSize outfile [-combinatorial]' % sys.argv[0])
        print('\tNote: possible number of indels: 0 or 1')
        print('\tNote: the script will output all possible mismatches tiling the guides; use the [-combinatorial] option if you want all possible guides')
        print('\tNote: the script will output all mismatches and indels UP TO the specified length (i.e. not just of that length)')
        print('\tNote: right now only a single indel is allowed')
        print('\tNote: for the [-combinatorial] option, only second-order combinatorics is allowed now, i.e. even number of mismatches, and pairwise positions')
        sys.exit(1)

    doCombinatorial = '-combinatorial' in sys.argv
    if doCombinatorial:
        print('will output combinatorial pairwise mismatches')

    inputfilename = sys.argv[1]
    IDfieldID = int(sys.argv[2])
    seqfieldID = int(sys.argv[3])
    wanted = sys.argv[4]
    fieldID = int(sys.argv[5])
    nMM = int(sys.argv[6])
    lenDel = int(sys.argv[7])
    lenIns = int(sys.argv[8])
    outfilename = sys.argv[9]

    # Warn once up front instead of once per processed guide.
    if doCombinatorial and nMM % 2 != 0:
        print('warning: only even number of mismatches allowed with the [-combinatorial] option')

    wantedIDs = set()
    with open(wanted) as wantedfile:
        for line in wantedfile:
            if line.startswith('#'):
                continue
            wantedIDs.add(line.strip().split('\t')[fieldID])

    with open(inputfilename) as infile, open(outfilename, 'w') as outfile:
        for line in infile:
            record = line.strip()
            if line.startswith('#'):
                outfile.write(record + '\tmismatchedGuide\tpos\tmismatches\tdeletion\tinsertion\n')
                continue
            fields = record.split('\t')
            if fields[IDfieldID] not in wantedIDs:
                continue
            sequence = fields[seqfieldID]
            if doCombinatorial:
                mismatches = _pairwise_mismatches(sequence, nMM)
            else:
                mismatches = _tiled_mismatches(sequence, nMM)
            # Output order per guide: mismatches, then deletions, then insertions.
            variants = itertools.chain(mismatches,
                                       _deletions(sequence, lenDel),
                                       _insertions(sequence, lenIns))
            for newseq, pos, nmismatch, ndel, nins in variants:
                outfile.write('%s\t%s\t%s\t%d\t%d\t%d\n' % (record, newseq, pos, nmismatch, ndel, nins))

# Only start the command-line tool when executed as a script, so that
# importing this module does not trigger argument parsing or file output.
if __name__ == '__main__':
    run()
