##################################
#                                #
# Last modified 2017/11/29       #
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
import os
import random	
from sets import Set
import time
from multiprocessing import Pool
from threading import Thread
from itertools import izip,islice,tee
import regex

def sub_findre(s, substring, diffnumber):
    """Yield each window of *s* that matches *substring* with at most *diffnumber* mismatches.

    For every start position ``i`` of *s*, the window ``s[i:i+len(substring)]``
    is compared position-by-position with *substring*.  If at least
    ``len(substring) - diffnumber`` positions agree, the window is yielded
    (joined into a string).  Windows near the end of *s* are shorter than
    *substring*; their missing positions simply count as mismatches, so a
    truncated window is still yielded when *diffnumber* allows it.

    An empty *substring* yields nothing.
    """
    sublen = len(substring)
    if sublen == 0:
        # Nothing to compare against; the previous implementation also
        # produced no output here (it ended via a stray StopIteration).
        return
    threshold = sublen - diffnumber
    # enumerate() gives lazy indices (no O(len(s)) list under Python 2), and
    # slicing avoids islice() re-walking the first `start` items each time.
    for start, _ in enumerate(s):
        window = s[start:start + sublen]
        matches = sum(1 for a, b in zip(substring, window) if a == b)
        if matches >= threshold:
            yield ''.join(window)

# IUPAC nucleotide complements.  Ambiguity codes complement to their partner
# code (R<->Y, M<->K, B<->V, D<->H); S, W, N and the placeholder X are their
# own complements.  Lower-case entries mirror the upper-case ones.
_IUPAC_COMPLEMENT = {
    'A': 'T', 'T': 'A', 'G': 'C', 'C': 'G',
    'R': 'Y', 'Y': 'R', 'M': 'K', 'K': 'M',
    'B': 'V', 'V': 'B', 'D': 'H', 'H': 'D',
    'S': 'S', 'W': 'W', 'N': 'N', 'X': 'X',
}
_IUPAC_COMPLEMENT.update(
    dict((base.lower(), comp.lower()) for base, comp in list(_IUPAC_COMPLEMENT.items())))


def getReverseComplement(preliminarysequence):
    """Return the reverse complement of a nucleotide sequence.

    Case is preserved per base.  IUPAC ambiguity codes are complemented
    correctly (e.g. R -> Y, M -> K).

    Raises:
        KeyError: if the sequence contains a character that is not a
            recognised nucleotide / IUPAC code.
    """
    # join() over reversed() is linear; the old `s = s + c` loop was quadratic.
    return ''.join(_IUPAC_COMPLEMENT[base] for base in reversed(preliminarysequence))

def select_guides((FDict,TDict,TranscriptDict,TranscriptSeqDict,GenomeDict,N,L,maxMM,do5pG,doUT,doAll,nTS)):
    """Select guide RNAs (gRNAs) for every transcript of every gene in FDict.

    Python 2 only: the single parameter is a tuple unpacked in the signature,
    presumably so the function can be mapped over argument tuples (e.g. with
    multiprocessing.Pool.map) -- TODO confirm.

    Tuple fields:
        FDict             -- gene -> {transcript: sequence}; the genes handled by this call.
        TDict             -- dict filled in place and returned:
                             gene -> {transcript: [gRNA, ...]}.
        TranscriptDict    -- gene -> {transcript: [(chr, left, right, strand), ...]};
                             used to check that a single exact genomic hit lies
                             inside an exon of the transcript.
        TranscriptSeqDict -- gene -> {transcript: sequence}; used for L-mer
                             counting and for the transcriptome uniqueness scan.
        GenomeDict        -- chromosome name -> sequence, scanned for off-target hits.
        N                 -- number of guides to pick per transcript; also the
                             minimum number of shared L-mers used to choose the
                             isoform-presence threshold below.
        L                 -- guide length.
        maxMM             -- maximum number of substitutions allowed in the fuzzy
                             off-target searches (regex `{s<=maxMM}`).
        do5pG             -- only seed from windows ending in 'C', i.e. guides
                             starting with 'G'.
        doUT              -- keep only L-mers that occur exactly once across the
                             gene's transcripts.
        doAll             -- report a guide for every candidate position instead
                             of N spread-out ones.
        nTS               -- if true, skip the transcriptome uniqueness scan.

    Windows (gRNArevcomp) are taken from the transcript sequence; the guide
    itself (gRNA) is their reverse complement.

    Returns:
        TDict (the same object that was passed in).
    """

    # Chromosomes sorted by (length, name): the genome scan visits the
    # shortest sequences first.
    chrs = []
    for chr in GenomeDict.keys():
        chrs.append((len(GenomeDict[chr]),chr))
    chrs.sort()

    genes = FDict.keys()
    G = 0
    for gene in genes:
        G += 1
        print G, gene
        TDict[gene] = {}
        # Count how often each L-mer occurs across all transcripts of the gene.
        LMerCountsDict = {}
        for transcript in FDict[gene].keys():
            sequence = TranscriptSeqDict[gene][transcript]
            # NOTE(review): range(len(sequence)-L) stops one short, so the
            # final L-mer of each transcript is never counted -- confirm.
            for pos in range(len(sequence)-L):
                gRNArevcomp = gRNArevcomp = sequence[pos:pos+L]
                # The guide is the reverse complement of this window, so a
                # trailing 'C' here means the guide starts with 'G'.
                if do5pG and gRNArevcomp[-1] != 'C':
                    continue
                if LMerCountsDict.has_key(gRNArevcomp):
                    pass
                else:
                    LMerCountsDict[gRNArevcomp] = 0
                LMerCountsDict[gRNArevcomp] += 1
        # Histogram: occurrence count -> number of distinct L-mers with that
        # count.  Occurrence counts equal isoform counts only if an L-mer
        # appears at most once per transcript.
        IsoPresenceCounts = {}
        for gRNArevcomp in LMerCountsDict.keys():
            IsoPresence = LMerCountsDict[gRNArevcomp]
            if IsoPresenceCounts.has_key(IsoPresence):
                pass
            else:
                IsoPresenceCounts[IsoPresence] = 0
            IsoPresenceCounts[IsoPresence] += 1
        # Walk IP from Tnumber down to 1 and take the first (largest) IP that
        # has at least N distinct L-mers.
        Tnumber = len(FDict[gene].keys())
        IP = Tnumber
        for i in range(Tnumber):
            IP = Tnumber - i
            if IsoPresenceCounts.has_key(IP):
                pass
            else:
                continue
            if IsoPresenceCounts[IP] >= N:
                Tnum = IP
                break
        # NOTE(review): if no IP qualifies above, Tnum is never assigned for
        # this gene; when doUT is false the `< Tnum` test below then raises
        # UnboundLocalError (first gene) or reuses the previous gene's value.
        # Filter L-mers: unique to one transcript (doUT) or present at least
        # Tnum times.  .keys() is a list in Python 2, so deleting is safe here.
        for gRNArevcomp in LMerCountsDict.keys():
            if doUT:
                if LMerCountsDict[gRNArevcomp] > 1:
                    del LMerCountsDict[gRNArevcomp]
            else:
                if LMerCountsDict[gRNArevcomp] < Tnum:
                    del LMerCountsDict[gRNArevcomp]
        for transcript in FDict[gene].keys():
            TDict[gene][transcript] = []
#            print N, TDict[gene], len(TDict[gene])
#           if len(TDict[gene]) >= N:
#                break
            sequence = FDict[gene][transcript]
            # Start positions of this transcript whose L-mer survived filtering.
            possiblePositions = []
            for i in range(len(sequence)-L):
                gRNArevcom = sequence[i:i+L]
                if LMerCountsDict.has_key(gRNArevcom):
                    possiblePositions.append(i)
            print possiblePositions
            if len(possiblePositions) == 0:
                continue
            # Spread N seeds evenly over the candidate positions (integer
            # division under Python 2), or use every candidate with doAll.
            if doAll:
                step = 1
            else:
                if len(possiblePositions) < N:
                    continue
                step = len(possiblePositions)/(N+1)
            RN = N
            if doAll:
                RN = len(possiblePositions)
            pos1 = 0
            for i in range(RN):
                NotFound = True
                print gene, transcript, i, step, step*i, len(possiblePositions)
                # Never move backwards: continue from the previous guide if it
                # is already past this seed.
                pos1 = max(pos1,possiblePositions[step*i])
                if pos1 >= max(possiblePositions):
                    break
                j = 0
                # Slide one base at a time until a window passes every
                # uniqueness check.
                # NOTE(review): pos1 is incremented before the first test, so
                # the seed position itself is never tried, and later windows
                # are not restricted to possiblePositions (nor to do5pG).
                while NotFound and pos1 < max(possiblePositions):
                    pos1 += 1
                    gRNArevcomp = sequence[pos1:pos1+L]
                    gRNA = getReverseComplement(gRNArevcomp)
                    Unique = True
                    chrMatches0 = []   # exact (0-substitution) genomic hits
                    chrMatchesMM = []  # fuzzy (<= maxMM substitutions) genomic hits
                    for (LLL,chr) in chrs:
                        if '_' not in chr:
                            print chr, 
                        tseq = GenomeDict[chr]
#####                        m0 = regex.compile(gRNArevcomp + '{e<1}')
#####                        matches0 = m0.findall(tseq)

                        # Fuzzy hit of the window, then a second exact hit
                        # after it on the same chromosome.
                        # NOTE(review): the second search slices
                        # tseq[start+1:-1], which drops the last base, and the
                        # stored position omits that +1 offset.
                        mM = regex.search(r'(?:' + gRNArevcomp + '){s<=' + str(maxMM) + '}', tseq)
                        if str(mM) != 'None':
                            chrMatchesMM.append((chr,mM.span()[0]))
                            mM2 = regex.search(r'(?:' + gRNArevcomp + '){s<1}', tseq[mM.span()[0]+1:-1])
                            if str(mM2) != 'None':
                                chrMatchesMM.append((chr,mM.span()[0] + mM2.span()[0]))
                        # Two hits are enough to reject; stop scanning.
                        if len(chrMatchesMM) >= 2:
                            print chrMatchesMM
                            break

                        # Same fuzzy search for the guide orientation.
                        mM = regex.search(r'(?:' + gRNA + '){s<=' + str(maxMM) + '}', tseq)
                        if str(mM) != 'None':
                            chrMatchesMM.append((chr,mM.span()[0]))
                            mM2 = regex.search(r'(?:' + gRNA + '){s<1}', tseq[mM.span()[0]+1:-1])
                            if str(mM2) != 'None':
                                chrMatchesMM.append((chr,mM.span()[0] + mM2.span()[0]))
                        if len(chrMatchesMM) >= 2:
                            print chrMatchesMM
                            break

                        # Exact hits of the window.
                        m0 = regex.search(r'(?:' + gRNArevcomp + '){s<1}', tseq)
                        if str(m0) != 'None':
                            chrMatches0.append((chr,m0.span()[0]))
                            m02 = regex.search(r'(?:' + gRNArevcomp + '){s<1}', tseq[m0.span()[0]+1:-1])
                            if str(m02) != 'None':
                                chrMatches0.append((chr,m0.span()[0] + m02.span()[0]))
                        # Two exact hits: copy them into chrMatchesMM so the
                        # classification below rejects the guide.
                        if len(chrMatches0) >= 2:
                            chrMatchesMM = chrMatches0
                            break

                        # Exact hits of the guide orientation (checked only
                        # after the chromosome loop).
                        m0 = regex.search(r'(?:' + gRNA + '){s<1}', tseq)
                        if str(m0) != 'None':
                            chrMatches0.append((chr,m0.span()[0]))
                            m02 = regex.search(r'(?:' + gRNA + '){s<1}', tseq[m0.span()[0]+1:-1])
                            if str(m02) != 'None':
                                chrMatches0.append((chr,m0.span()[0] + m02.span()[0]))

                    print len(chrMatches0), len(chrMatchesMM)
                    print '.', Unique, '|', 
                    # Classify the genomic hits:
                    #   >= 2 fuzzy hits                  -> not unique
                    #   1 fuzzy hit, no exact hit        -> not unique
                    #   1 fuzzy + 1 exact                -> unique only if the exact
                    #                                       hit is on this transcript's
                    #                                       chromosome inside one of
                    #                                       its exons
                    # NOTE(review): a window with no genomic hit at all keeps
                    # Unique = True.
                    if len(chrMatchesMM) >= 2:
                        Unique = False
#                        break
                    elif len(chrMatchesMM) == 1 and len(chrMatches0) == 0:
                        Unique = False
#                        break
                    elif len(chrMatchesMM) == 1 and len(chrMatches0) == 1:
                        (chr,match) = chrMatches0[0]
                        tseq = GenomeDict[chr]
                        if chr != TranscriptDict[gene][transcript][0][0]:
                            Unique = False
                            print 'F1'
#                            break
                        else:
                            InGene = False
#############
                            # Guide-orientation hit must fall in a '-' strand
                            # exon; window-orientation hit in a '+' strand exon.
                            matchPos = tseq.find(gRNA)
                            print 'matchPos_gRNA', matchPos
                            for (c,l,r,s) in TranscriptDict[gene][transcript]:
                                if matchPos >= l and matchPos <= r:
                                    if s == '-':
                                        InGene = True
#                                        break
                            print 'IGP1:', InGene
                            matchPos = tseq.find(gRNArevcomp)
                            print 'matchPos_gRNArevcomp', matchPos
                            print 'matchPos_gRNArevcomp', matchPos
                            for (c,l,r,s) in TranscriptDict[gene][transcript]:
                                if matchPos >= l and matchPos <= r:
                                    if s == '+':
                                        InGene = True
#                                        break
                            print 'IGP2:', InGene
                            if not InGene:
                                Unique = False
#                                break
#############
                    print '..', Unique, '||', 
                    if not Unique:
                        print j, pos1, 'jjj'
                        j+=1
                        continue
                    # Transcriptome uniqueness scan (window orientation only).
                    if nTS:
                        pass
                    else:
#                        m0 = regex.compile(gRNArevcomp + '{e<1}')
#                        mM = regex.compile(gRNArevcomp + '{e<' + str(maxMM) + '}')
#                        matches0 = regex.findall(r'(?:' + gRNArevcomp + '){s<1}', tseq)
#                        matchesMM = regex.findall(r'(?:' + gRNArevcomp + '){s<=' + str(maxMM) + '}', tseq)
                        for gID in TranscriptSeqDict.keys():
                            if not Unique:
                                break
                            if gID == gene:
                                # Own gene: reject on >1 exact hit, >=2 fuzzy
                                # hits, or a fuzzy hit without an exact hit in
                                # any of its transcripts.
                                for TID in TranscriptSeqDict[gID].keys():
                                    tseq = TranscriptSeqDict[gID][TID]
                                    if len(regex.findall(r'(?:' + gRNArevcomp + '){s<1}', tseq)) > 1:
                                        Unique = False
                                        break
                                    if len(regex.findall(r'(?:' + gRNArevcomp + '){s<=' + str(maxMM) + '}', tseq)) >= 2:
                                        Unique = False
                                        break
                                    if len(regex.findall(r'(?:' + gRNArevcomp + '){s<=' + str(maxMM) + '}', tseq)) == 1 and len(regex.findall(r'(?:' + gRNArevcomp + '){s<1}', tseq)) == 0:
                                        Unique = False
                                        break
                            else:
                                # Other genes: any fuzzy hit rejects.
                                # NOTE(review): this `break` is at the for-loop
                                # level, so only the first transcript (in dict
                                # order) of each other gene is ever scanned.
                                for TID in TranscriptSeqDict[gID].keys():
                                    tseq = TranscriptSeqDict[gID][TID]
                                    if len(regex.findall(r'(?:' + gRNArevcomp + '){s<=' + str(maxMM) + '}', tseq)) >= 1:
                                        Unique = False
                                    break
                    print Unique
                    if Unique:
                        NotFound = False
                    else:
                        j+=1
#                print gRNA, gRNArevcomp
                # NOTE(review): appended even when the while loop ended with
                # NotFound still True, i.e. the last (rejected) guide tried.
                TDict[gene][transcript].append(gRNA)
                # De-duplicate (Set loses the original order).
                TDict[gene][transcript] = list(Set(TDict[gene][transcript]))

    return TDict

def run():

    if len(sys.argv) < 8:
        print 'usage: python %s list_of_transcript_IDs fieldID gtf genome.fa minMismatchesRequired N_guides_to_pick_per_gene guide_length outfile [-p N_CPUs] [-5pG] [-uniqueToTranscripts] [-all] [-noTranscriptomeScan]' % sys.argv[0]
        print '\tnote: input files can be zipped'
        print '\tnote: the script will start from the middle of the sequence and look for additional guides moving outwards at equal intervals'
        print '\tnote: use the -5pG option to require a G in the 5p end (for U6 expression)'
        print '\tnote: the script will look for sequences unique to the geneID, in addition, it will require that the guide is present in all isoforms'
        print '\t      except when there is sequence common to all isoforms, in which case the possible guides common to the largest number of isoforms'
        print '\t      higher than the requested number of guides will be used'
        print '\tnote: if you want to target individual isoforms instead, use the [-uniqueToTranscripts] option'
        print '\tnote: if you want all valid guides, use the [-all] option'
        sys.exit(1)

    geneIDs = sys.argv[1]
    fieldID = int(sys.argv[2])
    gtf = sys.argv[3]
    fasta = sys.argv[4]
    maxMM = int(sys.argv[5])
    N = int(sys.argv[6])
    L = int(sys.argv[7])
    outfilename = sys.argv[8]

    nTS = False
    if '-noTranscriptomeScan' in sys.argv:
        nTS = True
        print 'will not scan transcriptome'

    NP = 1
    if '-p' in sys.argv:
        NP = int(sys.argv[sys.argv.index('-p') + 1])
        print 'p =', NP

    do5pG = False
    if '-5pG' in sys.argv:
        do5pG = True
        print 'will require a 5p G'

    doAll = False
    if '-all' in sys.argv:
        doAll = True
        print 'will print all valid guides'

    doUT = False
    if '-uniqueToTranscripts' in sys.argv:
        doUT = True
        print 'will select guides unique to each transcript'

    WantedGenes = {}
    if geneIDs.endswith('.bz2'):
        cmd = 'bzip2 -cd ' + geneIDs
    elif geneIDs.endswith('.gz'):
        cmd = 'gunzip -c ' + geneIDs
    else:
        cmd = 'cat ' + geneIDs
    p = os.popen(cmd, "r")
    line = 'line'
    while line != '':
        line = p.readline()
        if line == '':
            break
        if line.startswith('#'):
            continue
        fields = line.strip().split('\t')
        ID = fields[fieldID]
        WantedGenes[ID] = 1

    print 'finished parsing input'

    GenomeDict = {}
    sequence = ''
    if fasta.endswith('.bz2'):
        cmd = 'bzip2 -cd ' + fasta
    elif fasta.endswith('.gz'):
        cmd = 'gunzip -c ' + fasta
    else:
        cmd = 'cat ' + fasta
    p = os.popen(cmd, "r")
    line = 'line'
    Keep=False
    while line != '':
        line = p.readline()
        if line == '':
            break
        if line[0]=='>':
            if sequence != '':
                GenomeDict[chr] = ''.join(sequence).upper()
            chr = line.strip().split('>')[1]
            print chr
            sequence=[]
            Keep=False
            continue
        else:
            sequence.append(line.strip())
    GenomeDict[chr] = ''.join(sequence).upper()

    print 'finished parsing genomic fasta file'

    TranscriptDict = {}
    G = 0
    T = 0
    if gtf.endswith('.bz2'):
        cmd = 'bzip2 -cd ' + gtf
    elif gtf.endswith('.gz'):
        cmd = 'gunzip -c ' + gtf
    else:
        cmd = 'cat ' + gtf
    p = os.popen(cmd, "r")
    line = 'line'
    Keep=False
    while line != '':
        line = p.readline()
        if line == '':
            break
        if line.startswith('#'):
            continue
        fields=line.strip().split('\t')
        if fields[2] != 'exon':
            continue
        chr = fields[0]
        if GenomeDict.has_key(chr):
            pass
        else:
            continue
        if 'gene_name "' in fields[8]:
            geneName=fields[8].split('gene_name "')[1].split('";')[0]
        else:
            geneName=fields[8].split('gene_id "')[1].split('";')[0]
        geneID=fields[8].split('gene_id "')[1].split('";')[0]
        if 'transcript_name "' in fields[8]:
            transcriptName = fields[8].split('transcript_name "')[1].split('";')[0]
        else:
            transcriptName = fields[8].split('transcript_id "')[1].split('";')[0]
#       print fields[8]
#        print fields[8].split('transcript_id "')
#        print fields[8].split('transcript_id "')[1]
#        print fields[8].split('transcript_id "')[1].split('";')[0]
        transcriptID = fields[8].split('transcript_id "')[1].split('";')[0]
        transcript = (geneID, geneName, transcriptName, transcriptID)
        if TranscriptDict.has_key(geneID):
            pass
        else:
            TranscriptDict[geneID] = {}
            G += 1
        if TranscriptDict[geneID].has_key(transcript):
            pass
        else:
            TranscriptDict[geneID][transcript] = []
            T += 1
        left = int(fields[3])
        right = int(fields[4])
        orientation = fields[6]
        TranscriptDict[geneID][transcript].append((chr,left,right,orientation))

    print 'finished parsing genomic GTF file'
    print 'found', G, 'genes and', T, 'transcripts'
    print len(WantedGenes.keys()), 'genes targeted'


    TranscriptSeqDict = {}

    for geneID in TranscriptDict.keys():
        TranscriptSeqDict[geneID] = {}
        for transcript in TranscriptDict[geneID].keys():
            sequence=''
            leftEnds=[]
            rightEnds=[] 
            TranscriptDict[geneID][transcript].sort()
            orientation = TranscriptDict[geneID][transcript][0][3]
#            if orientation == '.':
#                if doKeepUundeterminedStrand:
#                    orientation = '+'
#                else:
#                    continue
            if orientation == '+' or orientation == 'F':
                for (chr,left,right,orientation) in TranscriptDict[geneID][transcript]:
                    leftEnds.append(left)
                    rightEnds.append(right)
                    sequence = sequence + GenomeDict[chr][left-1:right]
                sense='plus_strand'
            if orientation == '-' or orientation == 'R':
                for (chr,left,right,orientation) in TranscriptDict[geneID][transcript]:
                    leftEnds.append(left)
                    rightEnds.append(right)
                    sequence = sequence + GenomeDict[chr][left-1:right]
                sense='minus_strand'
                sequence = getReverseComplement(sequence)
            LeftEnd = min(leftEnds)
            RightEnd = max(rightEnds)
            TranscriptSeqDict[geneID][transcript] = sequence
  
    WGenes = WantedGenes.keys()
    k = len(WGenes)/NP

    outfile = open(outfilename, 'w')
    outline = '#geneID\tgeneName\ttranscriptID\tgRNA_ID\tgRNA_sequence\tGC%'
    outfile.write(outline + '\n')

    TDictArray = []
    j=0
    G=0
    for i in range(NP):
        print i
        FDict = {}
        TDict = {}
        if i+1 == NP:
            while j < len(WGenes):
                gene = WGenes[j]
                FDict[gene] = TranscriptSeqDict[gene]
                j += 1
        else:
            while j < k*(i+1):
                gene = WGenes[j]
                FDict[gene] = TranscriptSeqDict[gene]
                j += 1
#        TDictArray.append((FDict,TDict,TranscriptDict,TranscriptSeqDict,GenomeDict,N,L,maxMM,do5pG,doUT,doAll))
        SG = select_guides((FDict,TDict,TranscriptDict,TranscriptSeqDict,GenomeDict,N,L,maxMM,do5pG,doUT,doAll,nTS))
        for gene in SG.keys():
            G += 1
            geneName = TranscriptDict[gene].keys()[0][1]
            for transcript in SG[gene]:
                k = 0
                for gRNA in SG[gene][transcript]:
                    k += 1
                    GC = (gRNA.count('G') + gRNA.count('C'))/(len(gRNA) + 0.0)
                    outline = gene + '\t' + geneName + '\t' + transcript[3] + '\t' + gene + '::' + geneName + '::' + transcript[3] + '::' + str(k) + '\t' + gRNA + '\t' + str(GC)
                    outfile.write(outline + '\n')

#    sys.exit(1)

#    p = Pool(NP)
#    SelectedGuides = p.map(select_guides, TDictArray)
#    p.close()
#    p.join()

    print 'finished selecting guides'
#    print SelectedGuides

#    for SG in SelectedGuides:
#        for gene in SG.keys():
#            k = 0
#            for gRNA in SG[gene]:
#                k += 1
#                GC = (gRNA.count('G') + gRNA.count('C'))/(len(gRNA) + 0.0)
#                outline = gene + '\t' + gene + '_' + str(k) + '\t' + gRNA + '\t' + str(GC)
#                outfile.write(outline + '\n')
#
    outfile.close()

# Run the CLI only when executed as a script, not when imported.
if __name__ == '__main__':
    run()
