##################################
#                                #
# Last modified 2018/11/27       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import os
import string
import gzip
import time
import Levenshtein
# from multiprocessing import Pool
# from threading import Thread

# def simulate_converted_transcripts((FPKMGeneDict,TranscriptDict,CR1,CR2,TotalFPKM,Ncells,NtranscriptsPerCell)):
#    return TranscriptDict

def getReverseComplement(preliminarysequence):
    """Return the reverse complement of a DNA sequence.

    Each base is complemented (A<->T, C<->G, N->N) and the order of the
    bases is reversed. Upper- and lower-case letters are both accepted and
    the case of every base is preserved.

    Args:
        preliminarysequence: DNA sequence made of the characters
            A, C, G, T, N (either case).

    Returns:
        The reverse-complemented sequence as a string ('' for empty input).

    Raises:
        KeyError: if the sequence contains any other character.
    """
    DNA = {'A':'T','T':'A','G':'C','C':'G','N':'N','a':'t','t':'a','g':'c','c':'g','n':'n'}
    # Join once instead of growing a string base by base, which is
    # quadratic in the length of the sequence.
    return ''.join([DNA[base] for base in reversed(preliminarysequence)])

def _open_maybe_gzip(path):
    """Open ``path`` for reading, using gzip when the name ends in '.gz'."""
    if path.endswith('.gz'):
        return gzip.open(path)
    return open(path)


def _read_guides(path):
    """Return ``(guides, K)``: the upper-cased guides and their common length.

    The guide is the first tab-separated field of each line. Blank lines and
    lines starting with '#' are skipped. K is 0 when the file has no guides.
    Exits with status 1 if two guides differ in length.
    """
    guide_list = []
    guide_length = 0
    handle = _open_maybe_gzip(path)
    try:
        for line in handle:
            if line.startswith('#') or line.strip() == '':
                continue
            guide = line.strip().split('\t')[0].upper()
            guide_list.append(guide)
            if len(guide_list) == 1:
                # The first guide fixes the k-mer size used for the genome index.
                guide_length = len(guide)
            elif len(guide) != guide_length:
                print('guides of different length detected, exiting')
                sys.exit(1)
    finally:
        handle.close()
    return guide_list, guide_length


def _read_genome(path):
    """Return a dict mapping each FASTA record name to its upper-cased sequence.

    The record name is the text following '>' on the header line. Each name
    is printed as it is encountered. Exits with status 1 if sequence text
    appears before the first header.
    """
    genome = {}
    name = None
    chunks = []
    handle = _open_maybe_gzip(path)
    try:
        for line in handle:
            if line.startswith('>'):
                if name is not None:
                    genome[name] = ''.join(chunks).upper()
                name = line.strip().split('>')[1]
                print(name)
                chunks = []
            elif name is not None:
                chunks.append(line.strip())
            elif line.strip():
                print('sequence data found before the first FASTA header, exiting')
                sys.exit(1)
    finally:
        handle.close()
    # Flush the final record, which has no following header to trigger it.
    if name is not None:
        genome[name] = ''.join(chunks).upper()
    return genome


def _index_kmers(genome, K):
    """Index the start positions of every K-mer in every sequence.

    Returns ``(kmer_index, chrom_names)`` where
    ``kmer_index[kmer][chrom_id]`` is the list of 0-based start positions of
    ``kmer`` in that sequence and ``chrom_names[chrom_id]`` is its name.
    """
    kmer_index = {}
    chrom_names = {}
    for chrom_id, chrom in enumerate(genome, 1):
        chrom_names[chrom_id] = chrom
        seq = genome[chrom]
        # "+ 1" so that the K-mer ending on the last base is indexed too.
        for i in range(len(seq) - K + 1):
            if i % 1000000 == 0:
                print('kmer parsing %s %sM positions' % (chrom, i // 1000000))
            kmer = seq[i:i + K]
            kmer_index.setdefault(kmer, {}).setdefault(chrom_id, []).append(i)
    return kmer_index, chrom_names


def _write_hits(outfile, sgRNA, sequence, mismatches, strand, positions_by_chrom, chrom_names):
    """Write one tab-separated output row per genomic occurrence of a hit."""
    for chrom_id in positions_by_chrom:
        for pos in positions_by_chrom[chrom_id]:
            fields = [sgRNA, sequence, str(mismatches), chrom_names[chrom_id], str(pos), strand]
            outfile.write('\t'.join(fields) + '\n')


def run():
    """Find genomic sites within N Hamming mismatches of each guide.

    Command line: ``fasta inputfilename N_mismatches outputfilename``.
    The FASTA and guide files may be gzip-compressed ('.gz' suffix).

    Every distinct K-mer of the genome (K = guide length) is compared to each
    guide and to its reverse complement with ``Levenshtein.hamming``. Each
    occurrence of a K-mer with at most N mismatches is written as a row of
    ``sgRNA, genomic_sequence, mismatches, chr, pos, strand``. ``pos`` is the
    0-based start of the K-mer in the FASTA sequence; for '-' strand hits
    ``genomic_sequence`` is the reverse complement of that K-mer.

    Exits with status 1 on missing arguments or malformed input.
    """
    # Four positional arguments are required, so argv must hold five entries;
    # checking for fewer than four let a missing output name raise IndexError.
    if len(sys.argv) < 5:
        print('usage: python %s fasta inputfilename N_mismatches outputfilename' % sys.argv[0])
        print('\tNote: guides have to be the same length')
        sys.exit(1)

    fasta = sys.argv[1]
    guides = sys.argv[2]
    MM = int(sys.argv[3])
    outfilename = sys.argv[4]

    GuideList, K = _read_guides(guides)
    print('finished parsing guides')

    GenomeDict = _read_genome(fasta)
    print('finished parsing genome')

    KmerDict, CDict = _index_kmers(GenomeDict, K)
    print('finished parsing kmers')

    hamming = Levenshtein.hamming

    with open(outfilename, 'w') as outfile:
        outfile.write('#sgRNA\tgenomic_sequence\tmismatches\tchr\tpos\tstrand' + '\n')

        start = time.time()
        for PSGR, sgRNA in enumerate(GuideList, 1):
            if PSGR % 100 == 0:
                # Report the wall time taken by the last 100 guides.
                print(time.time() - start)
                start = time.time()
                print('%d sgRNAs processed' % PSGR)
            RCsgRNA = getReverseComplement(sgRNA)
            for KK, kmer in enumerate(KmerDict, 1):
                if KK % 1000000 == 0:
                    print('%s %d kmers processed for sgRNA %s' % (sgRNA, KK, sgRNA))
                H = hamming(kmer, sgRNA)
                if H <= MM:
                    _write_hits(outfile, sgRNA, kmer, H, '+', KmerDict[kmer], CDict)
                # A forward-strand K-mer close to the guide's reverse
                # complement is a site for the guide on the minus strand.
                H = hamming(kmer, RCsgRNA)
                if H <= MM:
                    _write_hits(outfile, sgRNA, getReverseComplement(kmer), H, '-', KmerDict[kmer], CDict)
   
# Script entry point. The call is unconditional, so importing this module
# also runs the search using the current sys.argv.
run()
