##################################
#                                #
# Last modified 2023/07/15       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import os
import gzip
from sets import Set
import Levenshtein
import numpy as np

def getReverseComplement(preliminarysequence):
    """Return the reverse complement of a nucleotide string.

    A/T and G/C are swapped (case is preserved); N, X and the ambiguity
    codes R, M, Y, S, K, W are mapped to themselves by the table below.

    Raises KeyError if the sequence contains a character that is not in
    the table.
    """
    # NOTE(review): under IUPAC conventions R<->Y and M<->K are complement
    # pairs; this table maps them to themselves. Left unchanged so that
    # existing outputs are not altered -- confirm whether that is intended.
    DNA = {'A':'T','T':'A','G':'C','C':'G','N':'N','X':'X','a':'t','t':'a','g':'c','c':'g','n':'n','x':'x','R':'R','r':'r','M':'M','m':'m','Y':'Y','y':'y','S':'S','s':'s','K':'K','k':'k','W':'W','w':'w'}
    # Walk the input back-to-front and join once instead of growing a
    # string one character at a time.
    return ''.join(DNA[base] for base in reversed(preliminarysequence))

def _loadBarcodes(path, fieldID, doRevComp):
    """Return the set of barcodes found in column `fieldID` of a tab-delimited file.

    Lines starting with '#' and blank lines are skipped. When `doRevComp`
    is true each barcode is reverse complemented before being stored.
    """
    barcodes = set()
    with open(path) as handle:
        for line in handle:
            if line.startswith('#'):
                continue
            line = line.strip()
            if not line:
                # A blank line carries no barcode; skipping it avoids an
                # IndexError (or an empty-string barcode) further down.
                continue
            BC = line.split('\t')[fieldID]
            if doRevComp:
                BC = getReverseComplement(BC)
            barcodes.add(BC)
    return barcodes


def _matchBarcode(observed, whitelist, maxEdit):
    """Return the whitelist barcode that `observed` corresponds to, or 'nan'.

    An exact match is returned as is. Otherwise the whitelist entries
    within `maxEdit` Levenshtein distance are considered, and the closest
    one is returned only if it is unique; ties and misses give 'nan'.
    """
    if observed in whitelist:
        return observed
    bestDist = len(observed)
    nearest = []
    for candidate in whitelist:
        dist = Levenshtein.distance(observed, candidate)
        if dist > maxEdit:
            continue
        if dist < bestDist:
            bestDist = dist
            nearest = [candidate]
        elif dist == bestDist:
            # Must be `elif`: otherwise a new best candidate is appended a
            # second time and a unique nearest barcode can never be reported.
            nearest.append(candidate)
    if len(nearest) == 1:
        return nearest[0]
    return 'nan'


def _readFastqRecords(path):
    """Yield (readID, sequence, quality) for every complete 4-line record of a gzipped FASTQ."""
    with gzip.open(path) as handle:
        record = []
        for line in handle:
            if not isinstance(line, str):
                # gzip returns bytes on Python 3; the parsing below needs text.
                line = line.decode('utf-8')
            record.append(line.strip())
            if len(record) == 4:
                # The third line (the '+' separator) is not needed.
                yield record[0], record[1], record[3]
                record = []


def run():
    """Correct the barcodes in FASTQ read headers against three whitelists.

    Command line: BC1file fieldID BC2file fieldID2 BC3file fieldID3 fastq.gz
    followed by the optional flags listed in the usage message. Read
    headers are expected to end in ':::[BC1+BC2+BC3]' or
    ':::[BC1+BC2+BC3+UMI]'. Each barcode is replaced by its whitelist
    match (or 'nan') and the resulting FASTQ records are printed to
    standard output.
    """
    # Seven positional arguments plus the script name are required.
    if len(sys.argv) < 8:
        print('usage: python %s BC1file fieldID BC2file fieldID2 BC3file fieldID3 fastq.gz [-BC1 string] [-BCedit N] [-revcompBC] [-keepUMI] [-addUMI pos UMIlen]' % sys.argv[0])
        print('\t the script will print out std out by default')
        print('\t it is expected that the input is the output of the UGBAMtoFASTQ.py/UGBAMtoFASTQ-V2.py scripts, i.e. headers that look like this:')
        print('\t @020684_1-UGAv3-174-0988986986:::[ACAGATTC+AATCCGTC+TGGCTTCA+GTATGGGCCC]')
        sys.exit(1)

    BC1file = sys.argv[1]
    fieldID1 = int(sys.argv[2])
    BC2file = sys.argv[3]
    fieldID2 = int(sys.argv[4])
    BC3file = sys.argv[5]
    fieldID3 = int(sys.argv[6])
    fastq = sys.argv[7]

    # Maximum Levenshtein distance at which a barcode is still corrected.
    BCedit = 1
    if '-BCedit' in sys.argv:
        BCedit = int(sys.argv[sys.argv.index('-BCedit') + 1])

    doRevComp = '-revcompBC' in sys.argv

    # -BC1: use a fixed string as BC1 instead of reading it from the header.
    doBC1string = '-BC1' in sys.argv
    if doBC1string:
        BC1string = sys.argv[sys.argv.index('-BC1') + 1]

    # -keepUMI: carry over the fourth header field as the UMI.
    doKeepUMI = '-keepUMI' in sys.argv

    # -addUMI: cut the UMI out of the read sequence itself.
    doAddUMI = '-addUMI' in sys.argv
    if doAddUMI:
        UMIpos = int(sys.argv[sys.argv.index('-addUMI') + 1])
        UMIlen = int(sys.argv[sys.argv.index('-addUMI') + 2])

    # Each whitelist is read from its own file using its own column index.
    BCDict = {}
    BCDict[1] = _loadBarcodes(BC1file, fieldID1, doRevComp)
    BCDict[2] = _loadBarcodes(BC2file, fieldID2, doRevComp)
    BCDict[3] = _loadBarcodes(BC3file, fieldID3, doRevComp)
    if doBC1string:
        # The fixed BC1 string is always accepted as an exact match.
        BCDict[1].add(BC1string)

    for readID, sequence, QC in _readFastqRecords(fastq):
        # Header tail looks like '[BC1+BC2+BC3]' or '[BC1+BC2+BC3+UMI]'.
        parts = readID.split('[')[1].split('+')
        if doBC1string:
            BC1seq = BC1string
        else:
            BC1seq = parts[0]
        BC2seq = parts[1]
        BC3seq = parts[2].split(']')[0]

        UMI = None
        if doKeepUMI:
            UMI = parts[3].split(']')[0]
        if doAddUMI:
            # Takes precedence over -keepUMI; everything up to the end of
            # the UMI is removed from both the sequence and the qualities.
            UMI = sequence[UMIpos:UMIpos+UMIlen]
            sequence = sequence[UMIpos+UMIlen:]
            QC = QC[UMIpos+UMIlen:]

        BC1 = _matchBarcode(BC1seq, BCDict[1], BCedit)
        BC2 = _matchBarcode(BC2seq, BCDict[2], BCedit)
        BC3 = _matchBarcode(BC3seq, BCDict[3], BCedit)

        if doKeepUMI or doAddUMI:
            newbarcode7 = '[' + BC1 + '+' + BC2 + '+' + BC3 + '+' + UMI + ']'
        else:
            newbarcode7 = '[' + BC1 + '+' + BC2 + '+' + BC3 + ']'
        newID = readID.split(':::')[0] + ':::' + newbarcode7
        print(newID)
        print(sequence)
        print('+')
        print(QC)
            
# Only start processing when executed as a script, so that importing this
# module does not parse sys.argv or read any files.
if __name__ == '__main__':
    run()
