##################################
#                                #
# Last modified 2023/04/12       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import os
import gzip
from sets import Set
import Levenshtein
import numpy as np

def getReverseComplement(preliminarysequence):
    """Return the reverse complement of a DNA sequence.

    A, C, G, T and N are supported in upper or lower case, and case is
    preserved (e.g. 'AcgT' -> 'AcgT', 'AACG' -> 'CGTT'). Any other
    character raises KeyError.
    """
    DNA = {'A': 'T', 'T': 'A', 'G': 'C', 'C': 'G', 'N': 'N',
           'a': 't', 't': 'a', 'g': 'c', 'c': 'g', 'n': 'n'}
    # One linear join over the reversed sequence; the previous per-character
    # `sequence = sequence + ...` rebuilt the whole string on every step.
    return ''.join(DNA[base] for base in reversed(preliminarysequence))

def _loadSgRNADict(sgRNAs, sgRNAfieldID, sgRNAlabelfieldID, sgRNAlen):
    """Read the tab-separated sgRNA master file into {trimmed sequence: label}.

    Lines starting with '#' and blank lines are skipped. Each sequence is cut
    to its first sgRNAlen characters, so master entries sharing that prefix
    overwrite one another (the last one read wins).
    """
    sgRNADict = {}
    with open(sgRNAs) as master:
        for line in master:
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            sgRNADict[fields[sgRNAfieldID][0:sgRNAlen]] = fields[sgRNAlabelfieldID]
    return sgRNADict


def _parseReadHeader(readID):
    """Return (BC1, BC2, BC3, UMIseq) from a header ending in '[BC1+BC2+BC3+UMI]'."""
    fields = readID.split('[')[1].split('+')
    return fields[0], fields[1], fields[2].split(']')[0], fields[3].split(']')[0]


def _assignSgRNA(sgRNAseq, sgRNADict, sgRNAedit):
    """Match an observed sgRNA sequence against the master list.

    Returns sgRNAseq itself on an exact hit. Otherwise it returns the unique
    master sgRNA at the smallest Levenshtein distance that is <= sgRNAedit,
    'ambiguous' when several are tied at that distance, and 'nan' when none
    is close enough or sgRNAedit is 0 or less.
    """
    if sgRNAseq in sgRNADict:
        return sgRNAseq
    if sgRNAedit <= 0:
        return 'nan'
    bestDist = sgRNAedit + 1
    nearest = []
    for candidate in sgRNADict:
        dist = Levenshtein.distance(sgRNAseq, candidate)
        if dist > sgRNAedit:
            continue
        if dist < bestDist:
            bestDist = dist
            nearest = [candidate]
        elif dist == bestDist:
            nearest.append(candidate)
    if not nearest:
        return 'nan'
    if len(nearest) == 1:
        return nearest[0]
    return 'ambiguous'


def _addUMI(UMIcounts, UMIseq, UMIedit):
    """Count one read for UMIseq in the {UMI: reads} dict UMIcounts.

    An exact hit is incremented. Otherwise the read is credited to the first
    already-seen UMI within UMIedit edits. A UMI with no such neighbour is
    added as a new UMI only if it contains no 'N'; otherwise the read is
    dropped.
    """
    if UMIseq in UMIcounts:
        UMIcounts[UMIseq] += 1
        return
    for UMI in UMIcounts:
        if Levenshtein.distance(UMIseq, UMI) <= UMIedit:
            # Treat the read as an error-containing copy of UMI so it still
            # contributes to that molecule's read count instead of vanishing.
            UMIcounts[UMI] += 1
            return
    if 'N' not in UMIseq:
        UMIcounts[UMIseq] = 1


def _countReads(reads, sgRNADict, sgRNAlen, sgRNApos, sgRNAstrand, sgRNAedit, UMIedit):
    """Tally UMIs and reads per cell barcode and sgRNA from 4-line FASTQ records.

    Returns {(BC1, BC2, BC3): {sgRNA: {UMI: reads}}}, where sgRNA may also be
    'ambiguous'. Records with any barcode equal to 'nan', with more than one
    'N' in the UMI, or with no matching sgRNA are skipped.
    """
    BCDict = {}
    readID = None
    sequence = None
    for RL, line in enumerate(reads):
        # RL counts lines; 4,000,000 lines are one million reads.
        if RL % 4000000 == 0:
            print(str(RL / 4000000.) + 'M reads processed')
        phase = RL % 4
        if phase == 0:
            readID = line.strip()
            continue
        if phase == 1:
            sequence = line.strip()
            continue
        if phase == 2:
            continue
        # phase 3 is the quality line: the record is complete (qualities unused).
        BC1, BC2, BC3, UMIseq = _parseReadHeader(readID)
        if 'nan' in (BC1, BC2, BC3):
            continue

        sgRNAseq = sequence[sgRNApos:sgRNApos + sgRNAlen]
        if sgRNAstrand == 'rev':
            sgRNAseq = getReverseComplement(sgRNAseq)

        if UMIseq.count('N') > 1:
            continue

        sgRNA = _assignSgRNA(sgRNAseq, sgRNADict, sgRNAedit)
        if sgRNA == 'nan':
            continue

        UMIcounts = BCDict.setdefault((BC1, BC2, BC3), {}).setdefault(sgRNA, {})
        _addUMI(UMIcounts, UMIseq, UMIedit)
    return BCDict


def _writeCounts(outfilename, BCDict, sgRNADict):
    """Write one tab-separated row per barcode/sgRNA pair with at least one UMI.

    Columns: barcode (BC1+BC2+BC3), sgRNA, label, UMIs, reads. Barcodes are
    written in sorted order. Rows for 'ambiguous' use 'ambiguous' as the
    label and 'nan' for reads.
    """
    with open(outfilename, 'w') as outfile:
        outfile.write('#barcode\tsgRNA\tlabel\tUMIs\treads\n')
        for BC in sorted(BCDict):
            for sgRNA, UMIcounts in BCDict[BC].items():
                readcounts = sum(UMIcounts.values())
                # An sgRNA entry can exist with no UMIs when every read for
                # it was dropped (e.g. N-containing UMIs); don't report those.
                if readcounts == 0:
                    continue
                if sgRNA == 'ambiguous':
                    label = 'ambiguous'
                    reads = 'nan'
                else:
                    label = sgRNADict[sgRNA]
                    reads = str(readcounts)
                row = ['+'.join(BC), sgRNA, label, str(len(UMIcounts)), reads]
                outfile.write('\t'.join(row) + '\n')


def run():
    """Command-line entry point: count sgRNA UMIs and reads per cell barcode.

    Positional arguments (sys.argv[1:9]): fastq.gz (or '-' for stdin),
    sgRNA master file, sgRNA field index, label field index, sgRNA length,
    sgRNA start position in the read, 'for'|'rev', output file name.
    Optional flags: -UMIedit N and -sgRNAedit N (both default to 1).
    Prints usage and exits with status 1 if arguments are missing.
    """
    # outfilename is sys.argv[8], so nine argv entries are required.
    if len(sys.argv) < 9:
        print('usage: python %s fastq.gz sgRNA_master_file sgRNA_fieldID sgRNA_label_fieldID sgRNAlen sgRNApos for|rev outfilename [-UMIedit N] [-sgRNAedit N]' % sys.argv[0])
        print('\t the script assumes that cell barcodes have already been annotated with SHARE-seq-barcode-annotate.py or SHARE-seq-barcode-annotate-UG.py')
        print('\t the default [-sgRNAedit] edit distance value is 1')
        print('\t the default [-UMIedit] edit distance value is 1')
        print('\t the script expects a file with headers that look like this: @020684_2-UGAv3-67-1721595185:::[AAGACGGA+CGCTGATC+AACTCACC+GGTCTCTCCG]')
        print('\t \t with the last entry being the UMI')
        print('\t use - for stdin for the reads')
        print('\t Note: the script will not work with multiple sgRNAs that share the first sgRNAlen positions, as it will trim all of the sequences in the master list down to that length')
        sys.exit(1)

    sgRNAedit = 1
    if '-sgRNAedit' in sys.argv:
        sgRNAedit = int(sys.argv[sys.argv.index('-sgRNAedit') + 1])
        print('will used a sgRNA edit distance of ' + str(sgRNAedit))

    UMIedit = 1
    if '-UMIedit' in sys.argv:
        UMIedit = int(sys.argv[sys.argv.index('-UMIedit') + 1])
        print('will used a UMI edit distance of ' + str(UMIedit))

    fastq = sys.argv[1]
    sgRNAs = sys.argv[2]
    sgRNAfieldID = int(sys.argv[3])
    sgRNAlabelfieldID = int(sys.argv[4])
    sgRNAlen = int(sys.argv[5])
    sgRNApos = int(sys.argv[6])
    sgRNAstrand = sys.argv[7]
    outfilename = sys.argv[8]

    sgRNADict = _loadSgRNADict(sgRNAs, sgRNAfieldID, sgRNAlabelfieldID, sgRNAlen)

    if fastq == '-':
        BCDict = _countReads(sys.stdin, sgRNADict, sgRNAlen, sgRNApos,
                             sgRNAstrand, sgRNAedit, UMIedit)
    else:
        with gzip.open(fastq) as reads:
            BCDict = _countReads(reads, sgRNADict, sgRNAlen, sgRNApos,
                                 sgRNAstrand, sgRNAedit, UMIedit)

    print(len(BCDict))

    _writeCounts(outfilename, BCDict, sgRNADict)
            
# Run only when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    run()
