##################################
#                                #
# Last modified 2018/03/31       #
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
import os
import random	
from sets import Set
import time
from multiprocessing import Pool
from threading import Thread
from itertools import izip,islice,tee
import regex

# SAM FLAG field meaning (bit mask, decimal value, description)
# 0x0001 1 the read is paired in sequencing, no matter whether it is mapped in a pair
# 0x0002 2 the read is mapped in a proper pair (depends on the protocol, normally inferred during alignment)
# 0x0004 4 the query sequence itself is unmapped
# 0x0008 8 the mate is unmapped
# 0x0010 16 strand of the query (0 for forward; 1 for reverse strand)
# 0x0020 32 strand of the mate
# 0x0040 64 the read is the first read in a pair
# 0x0080 128 the read is the second read in a pair
# 0x0100 256 the alignment is not primary (a read having split hits may have multiple primary alignment records)
# 0x0200 512 the read fails platform/vendor quality checks
# 0x0400 1024 the read is either a PCR duplicate or an optical duplicate
# 0x0800 2048 the alignment is supplementary

def FLAG(FLAG):
    """Decompose a SAM FLAG value into the individual bits that are set.

    Parameters:
        FLAG: the integer FLAG field of a SAM record.

    Returns:
        A list of the powers of two that add up to FLAG, largest first
        (e.g. 83 -> [64, 16, 2, 1]). Zero and negative values give [].
    """

    setBits = []

    # Test every bit directly, so that bits above 1024 (e.g. 2048) are
    # reported as themselves and never leak into the lower bits.
    bit = 1
    while bit <= FLAG:
        if FLAG & bit:
            setBits.append(bit)
        bit <<= 1

    # Callers receive the bits in descending order.
    setBits.reverse()

    return setBits

# Complement of every supported nucleotide code; case is preserved.
# The ambiguity codes pair up as R<->Y, M<->K and B<->V, D<->H, while
# S, W, N and X are their own complements.
_COMPLEMENT = {
    'A': 'T', 'T': 'A', 'G': 'C', 'C': 'G',
    'a': 't', 't': 'a', 'g': 'c', 'c': 'g',
    'N': 'N', 'X': 'X', 'n': 'n', 'x': 'x',
    'R': 'Y', 'Y': 'R', 'r': 'y', 'y': 'r',
    'M': 'K', 'K': 'M', 'm': 'k', 'k': 'm',
    'S': 'S', 'W': 'W', 's': 's', 'w': 'w',
    'B': 'V', 'V': 'B', 'b': 'v', 'v': 'b',
    'D': 'H', 'H': 'D', 'd': 'h', 'h': 'd',
}


def getReverseComplement(preliminarysequence):
    """Return the reverse complement of a nucleotide sequence.

    Parameters:
        preliminarysequence: a string of nucleotide codes (A, C, G, T, N, X
            and the IUPAC ambiguity codes R, Y, M, K, S, W, B, V, D, H),
            in upper or lower case.

    Returns:
        The reverse-complemented string, with the case of each base kept.

    Raises:
        KeyError: if the sequence contains a character that is not in the
            complement table.
    """

    # Build the result with a single join instead of repeated string
    # concatenation, which is quadratic in the sequence length.
    return ''.join([_COMPLEMENT[base] for base in reversed(preliminarysequence)])

def _iterCommandOutput(args, stdinPath=None):
    """Yield the standard-output lines of an external command.

    The command is started directly from the argument list `args`, without
    going through a shell, so file names containing spaces or shell
    metacharacters are passed through untouched. If `stdinPath` is given,
    that file is fed to the command on its standard input.
    """
    import subprocess

    stdinHandle = None
    if stdinPath is not None:
        stdinHandle = open(stdinPath, 'r')
    try:
        process = subprocess.Popen(args, stdin=stdinHandle,
                                   stdout=subprocess.PIPE,
                                   universal_newlines=True)
        try:
            for line in process.stdout:
                yield line
        finally:
            process.stdout.close()
            process.wait()
    finally:
        if stdinHandle is not None:
            stdinHandle.close()


def _iterTextFile(path):
    """Yield the lines of a plain, gzip (.gz) or bzip2 (.bz2) text file."""
    if path.endswith('.bz2'):
        args = ['bzip2', '-cd', path]
    elif path.endswith('.gz'):
        args = ['gunzip', '-c', path]
    else:
        args = ['cat', path]
    return _iterCommandOutput(args)


def _iterBowtieHits(bowtie, index, NP, maxMM, reportArgs, guideFasta):
    """Align the guide FASTA with bowtie and yield one tuple per SAM record.

    Each yielded item is (gene, transcript, pos, fields), where the first
    three values are decoded from the read name written by run()
    ('gene::transcriptID::position[::guide]') and `fields` is the
    tab-split SAM line. `reportArgs` holds the bowtie reporting options
    (e.g. ['-k', '4'] or ['-a']).
    """
    import shlex

    # shlex.split keeps working when the bowtie argument carries extra
    # options in addition to the executable name.
    args = shlex.split(bowtie) + [index, '-p', str(NP), '-v', str(maxMM)]
    args = args + reportArgs + ['-t', '-f', '--sam', '--sam-nh', '-']

    for line in _iterCommandOutput(args, guideFasta):
        # Skip bowtie timing output and the SAM header.
        if line.startswith(('Time loading', '@PG\t', '@HD\t', '@SQ\t')):
            continue
        fields = line.strip().split('\t')
        if len(fields) < 4:
            continue
        idParts = fields[0].split('::')
        yield (idParts[0], idParts[1], int(idParts[2]), fields)


def run():
    """Command-line entry point: design guides against a list of genes.

    Reads its parameters from sys.argv (run without arguments for the usage
    text) and then:

      1. loads the wanted gene IDs, the genome FASTA and the exon records
         of the GTF file;
      2. assembles the spliced sequence of every transcript of every
         wanted gene;
      3. enumerates all windows of the requested guide length, applies the
         5' G, A/T-stretch and restriction-site filters, and keeps the
         windows that satisfy the isoform-sharing rule selected on the
         command line;
      4. writes the surviving guides to '<outfile>.temp' and aligns them
         with bowtie against the genome and (unless -noTranscriptomeScan is
         given) the transcriptome, discarding guides with off-target hits;
      5. writes up to N guides per transcript (all of them with -all) to
         the output file as a tab-separated table.

    Exits with status 1 when too few arguments are given or when
    -perTranscript and -uniqueToTranscripts are combined.
    """

    # Eleven positional arguments are read below (sys.argv[1]..sys.argv[11]).
    if len(sys.argv) < 12:
        print('usage: python %s list_of_gene_IDs fieldID gtf genome.fa minMismatchesDistanceRequired N_guides_to_pick_per_gene guide_length bowtie bowtie_genome_index bowtie_transcriptome_index outfile [-p N_CPUs] [-5pG] [-significantBases pos1 pos2] [-uniqueToTranscripts] [-perTranscript] [-all] [-noTranscriptomeScan] [-ATstretchFilter bp] [-noRestrictionSites kmer1(,kmer2,...,kmerN) ]' % sys.argv[0])
        print('\tnote: input files can be zipped')
        print('\tnote: the script will start from the middle of the sequence and look for additional guides moving outwards at equal intervals')
        print('\tnote: use the -5pG option to require a G in the 5p end (for U6 expression)')
        print('\tnote: the script will look for sequences unique to the geneID, in addition, it will require that the guide is present in all isoforms')
        print('\t      except when there is no sequence common to all isoforms, in which case the possible guides common to the largest number of isoforms')
        print('\t      higher than the requested number of guides will be used')
        print('\tnote: use the [-significantBases] if you want mismatches and off-targets to be score only over the indicated stretch of the guide. For example, for the 18 positions from 4 to 22, use [-significantBases 4 22]')
        print('\tnote: if you want to target individual isoforms instead, use the [-uniqueToTranscripts] option')
        print('\tnote: if you want all valid guides, use the [-all] option')
        sys.exit(1)

    geneIDs = sys.argv[1]
    fieldID = int(sys.argv[2])
    gtf = sys.argv[3]
    fasta = sys.argv[4]
    maxMM = int(sys.argv[5])
    N = int(sys.argv[6])
    L = int(sys.argv[7])
    bowtie = sys.argv[8]
    bowtieGI = sys.argv[9]
    bowtieTI = sys.argv[10]
    outfilename = sys.argv[11]

    nTS = False
    if '-noTranscriptomeScan' in sys.argv:
        nTS = True
        print('will not scan transcriptome')

    REkmers = []
    if '-noRestrictionSites' in sys.argv:
        kmers = sys.argv[sys.argv.index('-noRestrictionSites') + 1].split(',')
        for kmer in kmers:
            REkmers.append(kmer.upper())
            REkmers.append(getReverseComplement(kmer).upper())
        print('will discard all guides containing the following kmers %s' % (REkmers,))

    # The homopolymer probes are always defined. With the default length of
    # L+1 they can never occur inside an L-mer, so the filter is a no-op
    # unless -ATstretchFilter is given.
    # NOTE(review): a guide is discarded when it contains a run of mAT
    # identical A or T bases, i.e. runs of mAT or more, although the message
    # below says "longer than" -- confirm which is intended.
    mAT = L + 1
    doATFilter = '-ATstretchFilter' in sys.argv
    if doATFilter:
        mAT = int(sys.argv[sys.argv.index('-ATstretchFilter') + 1])
    maxA = 'A' * mAT
    maxT = 'T' * mAT
    if doATFilter:
        print('will discard all guides with stretches of A/T longer than %s bases' % mAT)
        print('%s %s' % (maxA, maxT))

    NP = 1
    if '-p' in sys.argv:
        NP = int(sys.argv[sys.argv.index('-p') + 1])
        print('p = %s' % NP)

    do5pG = False
    if '-5pG' in sys.argv:
        do5pG = True
        print('will require a 5p G')

    doAll = False
    if '-all' in sys.argv:
        doAll = True
        print('will print all valid guides')

    doUT = False
    if '-uniqueToTranscripts' in sys.argv:
        doUT = True
        print('will select guides unique to each transcript')

    doPT = False
    if '-perTranscript' in sys.argv:
        doPT = True
        print('will select guides for each transcript separately, ignoring others')

    if doPT and doUT:
        print('incompatible options detected, exiting')
        sys.exit(1)

    doSB = False
    SB1 = 0
    SB2 = L
    if '-significantBases' in sys.argv:
        doSB = True
        SB1 = int(sys.argv[sys.argv.index('-significantBases') + 1])
        SB2 = int(sys.argv[sys.argv.index('-significantBases') + 2])
        print('will consider only bases %s to %s when considering uniqueness and off-target effects' % (SB1, SB2))
    # Guides are the reverse complement of the transcript windows, so bases
    # SB1..SB2 of a guide sit at L-SB2..L-SB1 of its window. Explicit,
    # clamped offsets keep the slice correct when SB1 is 0 or SB2 exceeds L.
    sbLeft = max(0, L - SB2)
    sbRight = max(0, L - SB1)

    # ---- wanted gene IDs ----
    WantedGenes = {}
    for line in _iterTextFile(geneIDs):
        if line.startswith('#') or line.strip() == '':
            continue
        fields = line.strip().split('\t')
        WantedGenes[fields[fieldID]] = 1

    print('finished parsing input')

    # ---- genome sequence, keyed by the full FASTA header text ----
    GenomeDict = {}
    chrom = None
    seqParts = []
    for line in _iterTextFile(fasta):
        if line.startswith('>'):
            if chrom is not None:
                GenomeDict[chrom] = ''.join(seqParts).upper()
            chrom = line.strip().split('>')[1]
            seqParts = []
        else:
            seqParts.append(line.strip())
    if chrom is not None:
        GenomeDict[chrom] = ''.join(seqParts).upper()

    print('finished parsing genomic fasta file')

    # ---- exon records of the GTF file ----
    # TIDtoTDict: transcriptID -> (geneID, geneName, transcriptName, transcriptID)
    # TranscriptDict: geneID -> transcript tuple -> [(chrom, left, right, strand)]
    TIDtoTDict = {}
    TranscriptDict = {}
    nGenes = 0
    nTranscripts = 0
    for line in _iterTextFile(gtf):
        if line.startswith('#'):
            continue
        fields = line.strip().split('\t')
        if len(fields) < 9 or fields[2] != 'exon':
            continue
        chrom = fields[0]
        if chrom not in GenomeDict:
            continue
        attributes = fields[8]
        geneID = attributes.split('gene_id "')[1].split('";')[0]
        if 'gene_name "' in attributes:
            geneName = attributes.split('gene_name "')[1].split('";')[0]
        else:
            geneName = geneID
        transcriptID = attributes.split('transcript_id "')[1].split('";')[0]
        if 'transcript_name "' in attributes:
            transcriptName = attributes.split('transcript_name "')[1].split('";')[0]
        else:
            transcriptName = transcriptID
        transcript = (geneID, geneName, transcriptName, transcriptID)
        TIDtoTDict[transcriptID] = transcript
        if geneID not in TranscriptDict:
            TranscriptDict[geneID] = {}
            nGenes += 1
        if transcript not in TranscriptDict[geneID]:
            TranscriptDict[geneID][transcript] = []
            nTranscripts += 1
        left = int(fields[3])
        right = int(fields[4])
        orientation = fields[6]
        TranscriptDict[geneID][transcript].append((chrom, left, right, orientation))

    print('finished parsing genomic GTF file')
    print('found %s genes and %s transcripts' % (nGenes, nTranscripts))
    print('%s genes targeted' % len(WantedGenes))

    # ---- spliced transcript sequences of the wanted genes ----
    TranscriptSeqDict = {}
    for geneID in WantedGenes:
        TranscriptSeqDict[geneID] = {}
        if geneID not in TranscriptDict:
            # A wanted gene without usable exon records gets no guides
            # instead of aborting the whole run.
            print('warning: gene %s has no exons in the GTF file, skipping' % geneID)
            continue
        for transcript in TranscriptDict[geneID]:
            exons = TranscriptDict[geneID][transcript]
            # Sorted in place: the genome scan below relies on exons[0].
            exons.sort()
            orientation = exons[0][3]
            sequence = ''.join([GenomeDict[c][l-1:r] for (c, l, r, s) in exons])
            if orientation == '+' or orientation == 'F':
                pass
            elif orientation == '-' or orientation == 'R':
                sequence = getReverseComplement(sequence)
            else:
                # Unknown strand: no sequence, hence no guides.
                sequence = ''
            TranscriptSeqDict[geneID][transcript] = sequence

    print('finished parsing transcript sequences')

    # ---- candidate guides, written as FASTA for bowtie ----
    tempfilename = outfilename + '.temp'
    tempoutfile = open(tempfilename, 'w')
    try:
        for gene in sorted(WantedGenes):
            # Occurrences of every acceptable window (and, separately, of
            # its significant-base stretch) over all transcripts of the gene.
            LMerCountsDict = {}
            SBCountsDict = {}
            for transcript in TranscriptSeqDict[gene]:
                sequence = TranscriptSeqDict[gene][transcript]
                # +1 so that the last full-length window is included.
                for pos in range(len(sequence) - L + 1):
                    gRNArevcomp = sequence[pos:pos+L]
                    # A 5' G in the guide is a 3' C in the transcript window.
                    if do5pG and gRNArevcomp[-1] != 'C':
                        continue
                    if maxA in gRNArevcomp or maxT in gRNArevcomp:
                        continue
                    if any(REST in gRNArevcomp for REST in REkmers):
                        continue
                    LMerCountsDict[gRNArevcomp] = LMerCountsDict.get(gRNArevcomp, 0) + 1
                    if doSB:
                        gRNArevcompSB = gRNArevcomp[sbLeft:sbRight]
                        SBCountsDict[gRNArevcompSB] = SBCountsDict.get(gRNArevcompSB, 0) + 1

            if not doPT:
                # Tnum: occurrence count a window needs in order to be kept.
                # It is the highest count, not above the number of
                # transcripts, reached by at least N windows. When no count
                # qualifies it stays at 1, which keeps every window.
                Tnum = 1
                if not doUT:
                    IsoPresenceCounts = {}
                    for IsoPresence in LMerCountsDict.values():
                        IsoPresenceCounts[IsoPresence] = IsoPresenceCounts.get(IsoPresence, 0) + 1
                    for IP in range(len(TranscriptSeqDict[gene]), 0, -1):
                        if IsoPresenceCounts.get(IP, 0) >= N:
                            Tnum = IP
                            break
                ValidGuides = {}
                for gRNArevcomp in LMerCountsDict:
                    count = LMerCountsDict[gRNArevcomp]
                    countSB = count
                    if doSB:
                        countSB = SBCountsDict[gRNArevcomp[sbLeft:sbRight]]
                    if doUT:
                        if count > 1 or countSB > 1:
                            continue
                    elif count < Tnum or countSB < Tnum:
                        continue
                    ValidGuides[gRNArevcomp] = count
                LMerCountsDict = ValidGuides

            for transcript in TranscriptSeqDict[gene]:
                sequence = TranscriptSeqDict[gene][transcript]
                for i in range(len(sequence) - L + 1):
                    gRNArevcomp = sequence[i:i+L]
                    if gRNArevcomp not in LMerCountsDict:
                        continue
                    gRNA = getReverseComplement(gRNArevcomp)
                    outline = '>' + gene + '::' + transcript[3] + '::' + str(i)
                    if doSB:
                        # Only the significant bases are aligned; the full
                        # guide is carried along in the read name.
                        tempoutfile.write(outline + '::' + gRNA + '\n')
                        tempoutfile.write(gRNA[SB1:SB2] + '\n')
                    else:
                        tempoutfile.write(outline + '\n')
                        tempoutfile.write(gRNA + '\n')
    finally:
        tempoutfile.close()

    print('finished identifying candidate guides')

    # ---- genome scan ----
    # CandidatePositions[gene][transcriptID][pos] ends up at 1 for guides
    # that pass and above 1 for guides that are discarded.
    CandidatePositions = {}
    for (gene, transcript, pos, fields) in _iterBowtieHits(bowtie, bowtieGI, NP, maxMM, ['-k', '4'], tempfilename):
        positions = CandidatePositions.setdefault(gene, {}).setdefault(transcript, {})
        positions[pos] = positions.get(pos, 0) + 1
        if positions[pos] > 1:
            # A second alignment record: the count itself disqualifies it.
            continue
        chrom = fields[2]
        if chrom == '*':
            # Unaligned guides are kept.
            continue
        # The first record has to fall inside an exon of the guide's own
        # transcript, on the strand opposite to the transcript; otherwise
        # the count is bumped so that the guide is dropped below.
        exons = TranscriptDict[gene][TIDtoTDict[transcript]]
        if chrom != exons[0][0]:
            positions[pos] += 1
            continue
        mappos = int(fields[3])
        if 16 in FLAG(int(fields[1])):
            strand = '-'
        else:
            strand = '+'
        if exons[0][3] == strand:
            positions[pos] += 1
            continue
        InTranscript = False
        for (c, l, r, s) in exons:
            if mappos >= l and mappos <= r:
                InTranscript = True
                break
        if not InTranscript:
            positions[pos] += 1

    print('finished scanning genome')

    for gene in CandidatePositions:
        for transcript in CandidatePositions[gene]:
            positions = CandidatePositions[gene][transcript]
            for pos in list(positions.keys()):
                if positions[pos] > 1:
                    del positions[pos]

    print('finished filtering postions based on genome mapping')

    # ---- transcriptome scan ----
    if not nTS:
        for (gene, transcript, pos, fields) in _iterBowtieHits(bowtie, bowtieTI, NP, maxMM, ['-a'], tempfilename):
            positions = CandidatePositions.get(gene, {}).get(transcript)
            if positions is None or pos not in positions:
                continue
            target = fields[2]
            if target == '*':
                continue
            if ':' in target:
                targetGene = target.split(':')[1]
            else:
                targetGene = target
            if targetGene != gene:
                # Hit on a transcript of another gene.
                del positions[pos]
            elif 16 not in FLAG(int(fields[1])):
                # A forward-strand hit on one of the gene's own transcripts
                # also disqualifies the guide.
                del positions[pos]

        print('finished mapping and filtering postions against the transcriptome')

    try:
        os.remove(tempfilename)
    except OSError:
        # Failing to clean up the temporary file is not fatal.
        pass

    # ---- output ----
    outfile = open(outfilename, 'w')
    try:
        outline = '#geneID\tgeneName\ttranscriptID\tgRNA_ID\tgRNA_sequence\tGC%'
        outfile.write(outline + '\n')

        for gene in sorted(CandidatePositions):
            geneName = next(iter(TranscriptDict[gene]))[1]
            for transcript in sorted(CandidatePositions[gene]):
                possiblePositions = sorted(CandidatePositions[gene][transcript])
                if len(possiblePositions) == 0:
                    continue
                if doAll or len(possiblePositions) <= N:
                    # Too few candidates to subsample (or -all): print all.
                    step = 1
                    RN = len(possiblePositions)
                else:
                    # Pick N candidates at equal intervals; step is >= 1
                    # here because there are more than N candidates.
                    step = len(possiblePositions) // (N + 1)
                    RN = N
                transcriptSeq = TranscriptSeqDict[gene][TIDtoTDict[transcript]]
                for i in range(RN):
                    pos1 = possiblePositions[step*i]
                    gRNA = getReverseComplement(transcriptSeq[pos1:pos1+L])
                    GC = (gRNA.count('G') + gRNA.count('C'))/(len(gRNA) + 0.0)
                    outline = gene + '\t' + geneName + '\t' + transcript + '\t' + gene + '::' + geneName + '::' + transcript + '::' + str(i + 1) + '\t' + gRNA + '\t' + str(GC)
                    outfile.write(outline + '\n')
    finally:
        outfile.close()

    print('finished selecting guides')

# Only start the pipeline when the file is executed as a script, so that
# importing it (e.g. to reuse FLAG or getReverseComplement) has no side effects.
if __name__ == '__main__':
    run()
