##################################
#                                #
# Last modified 2022/03/21       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import os
import gzip
from sets import Set
import Levenshtein
import numpy as np

# Per-base complement table (case-preserving). N/X pass through unchanged;
# two-base IUPAC ambiguity codes map to their complements: R (A/G) <-> Y (C/T),
# M (A/C) <-> K (G/T), while S (G/C) and W (A/T) are self-complementary.
_DNA_COMPLEMENT = {
    'A': 'T', 'T': 'A', 'G': 'C', 'C': 'G', 'N': 'N', 'X': 'X',
    'a': 't', 't': 'a', 'g': 'c', 'c': 'g', 'n': 'n', 'x': 'x',
    'R': 'Y', 'Y': 'R', 'r': 'y', 'y': 'r',
    'M': 'K', 'K': 'M', 'm': 'k', 'k': 'm',
    'S': 'S', 's': 's', 'W': 'W', 'w': 'w',
}


def getReverseComplement(preliminarysequence):
    """Return the reverse complement of a DNA sequence.

    Case is preserved. A/C/G/T are complemented normally, N and X map to
    themselves, and the IUPAC codes R/Y/M/K/S/W map to their complementary
    codes.

    Raises:
        KeyError: if the sequence contains any other character.
    """
    # join over a reversed iterator is linear; repeated string '+' was quadratic.
    return ''.join(_DNA_COMPLEMENT[base] for base in reversed(preliminarysequence))

def _read_barcodes(path, field_id, bc_len):
    """Parse the tab-separated barcode file at *path*.

    Returns (prefix_to_barcode, barcodes):
        prefix_to_barcode maps the first *bc_len* bases of each barcode to the
        full barcode string; barcodes lists every distinct full barcode in
        file order.
    Lines starting with '#' and blank lines are skipped.
    """
    prefix_to_barcode = {}
    barcodes = []
    seen = set()
    with open(path) as handle:
        for line in handle:
            if line.startswith('#') or not line.strip():
                continue
            barcode = line.strip().split('\t')[field_id]
            # A later barcode sharing the same prefix silently replaces an
            # earlier one; the usage text requires prefixes to be distinct.
            prefix_to_barcode[barcode[:bc_len]] = barcode
            if barcode not in seen:
                seen.add(barcode)
                barcodes.append(barcode)
    return prefix_to_barcode, barcodes


def _iter_fastq(handle):
    """Yield (read_id, sequence, quality) for each complete 4-line FASTQ record.

    Every line is whitespace-stripped, the '+' separator line is discarded,
    and a trailing incomplete record is dropped.
    """
    record = []
    for line in handle:
        record.append(line.strip())
        if len(record) == 4:
            yield record[0], record[1], record[3]
            record = []


def _match_barcode(observed, prefixes, max_edit):
    """Return the key of *prefixes* that *observed* resolves to, or None.

    An exact match wins outright. Otherwise the candidate with the smallest
    Levenshtein distance not exceeding *max_edit* is returned, provided it is
    unique; no candidate in range, or a tie at the smallest distance, gives None.
    """
    if observed in prefixes:
        return observed
    best_dist = None
    nearest = []
    for candidate in prefixes:
        dist = Levenshtein.distance(observed, candidate)
        if dist > max_edit:
            continue
        if best_dist is None or dist < best_dist:
            best_dist = dist
            nearest = [candidate]
        elif dist == best_dist:
            # elif (not a second if): a new best must not be appended twice,
            # which would make every unique hit look ambiguous.
            nearest.append(candidate)
    if len(nearest) == 1:
        return nearest[0]
    return None


def run():
    """Demultiplex a gzipped FASTQ by the barcode at the start of each read.

    Command line (read from sys.argv):
        BC1file fieldID BClen fastq.gz outprefix [-BCedit N]

    BC1file is tab-separated; column *fieldID* (0-based) holds the barcode.
    Reads are matched on their first BClen bases: an exact match is used
    directly, otherwise a unique nearest barcode prefix within N edits
    (default 1) is used; unmatched or ambiguous reads are dropped. Each
    assigned read is written to '<outprefix>.<barcode>.fastq' with the
    sequence and quality trimmed (see the NOTE below).
    """
    # Five positional arguments are required (argv[1]..argv[5]).
    if len(sys.argv) < 6:
        print('usage: python %s BC1file fieldID BClen fastq.gz outprefix [-BCedit]' % sys.argv[0])
        print('\t the script will assumes the FASTQ fiels start with the barcode and it is in the forward oritentation')
        print('\t the script assumes variable-length barcodes, but that barcodes are all distinct within their first BClen bases')
        sys.exit(1)

    bc_file = sys.argv[1]
    field_id = int(sys.argv[2])
    bc_len = int(sys.argv[3])
    fastq = sys.argv[4]
    outprefix = sys.argv[5]

    max_edit = 1
    if '-BCedit' in sys.argv:
        max_edit = int(sys.argv[sys.argv.index('-BCedit') + 1])

    prefix_to_barcode, barcodes = _read_barcodes(bc_file, field_id, bc_len)

    outfiles = {}
    try:
        for barcode in barcodes:
            outfiles[barcode] = open(outprefix + '.' + barcode + '.fastq', 'w')

        # Assumes Python 2, where gzip yields str lines usable with '\n' below.
        with gzip.open(fastq) as reads:
            for n_reads, (read_id, sequence, quality) in enumerate(_iter_fastq(reads), 1):
                if n_reads % 1000000 == 0:
                    print('%d M lines processed' % (n_reads // 1000000))

                prefix = _match_barcode(sequence[:bc_len], prefix_to_barcode, max_edit)
                if prefix is None:
                    continue
                barcode = prefix_to_barcode[prefix]

                # NOTE(review): trimming len(barcode) - 1 bases keeps the last
                # barcode base in the output read. Preserved from the original;
                # confirm whether this is intentional or an off-by-one.
                trim = len(barcode) - 1
                outfiles[barcode].write(read_id + '\n' + sequence[trim:] + '\n'
                                        + '+' + '\n' + quality[trim:] + '\n')
    finally:
        for handle in outfiles.values():
            handle.close()
            
# Only run the demultiplexer when executed as a script, not on import.
if __name__ == '__main__':
    run()
