##################################
#                                #
# Last modified 2022/04/05       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import os
import Levenshtein
import gzip

def _option_value(argv, flag, default=''):
    """Return the argument following `flag` in `argv`, or `default` if `flag` is absent.

    Exits with status 1 when the flag is the last argument and has no value.
    """
    if flag not in argv:
        return default
    pos = argv.index(flag) + 1
    if pos >= len(argv):
        print('option %s requires a value, exiting' % flag)
        sys.exit(1)
    return argv[pos]


def _min_pairwise_distance(barcodes, upper_bound):
    """Return the smallest non-zero Levenshtein distance between any two
    entries of `barcodes`, or `upper_bound` if no pair is closer than that."""
    smallest = upper_bound
    for first in barcodes:
        for second in barcodes:
            D = Levenshtein.distance(first, second)
            # D == 0 means the barcode was compared with itself (or a duplicate)
            if D != 0 and D < smallest:
                smallest = D
    return smallest


def _resolve_barcode(observed, known, minBetween, maxLev):
    """Map an observed index sequence onto one of the `known` barcodes.

    Returns `observed` itself on an exact match, the single known barcode
    within `maxLev` edits otherwise, or None when there is no candidate or
    the match is ambiguous.
    """
    if observed in known:
        return observed
    # When the reference barcodes are further apart than maxLev the search
    # stops at the first hit; otherwise all hits are collected so that an
    # ambiguous read can be rejected.
    # NOTE(review): Levenshtein distance obeys the triangle inequality, so a
    # read can only be guaranteed a unique hit when minBetween > 2*maxLev;
    # for maxLev < minBetween <= 2*maxLev the first hit found is accepted.
    stop_at_first = minBetween > maxLev
    candidates = []
    for ref in known:
        if Levenshtein.distance(observed, ref) <= maxLev:
            candidates.append(ref)
            if stop_at_first:
                break
    if len(candidates) == 1:
        return candidates[0]
    return None


def _write_pair(out1, out2, ID, seq1, scores1, seq2, scores2):
    """Write one read pair as two four-line FASTQ records (end 1 to `out1`,
    end 2 to `out2`); the end-2 ID has ' 1:N:0:' replaced by ' 2:N:0:'."""
    out1.write('@' + ID + '\n' + seq1 + '\n' + '+' + '\n' + scores1 + '\n')
    out2.write('@' + ID.replace(' 1:N:0:', ' 2:N:0:') + '\n' + seq2 + '\n' + '+' + '\n' + scores2 + '\n')


def run():
    """Demultiplex tab-delimited paired-end reads into per-sample FASTQ files.

    Command line (read from sys.argv):
        <inputfilename> config maxLevenshtein [-70xOnly] [-50xOnly]
        [-suffix string] [-prefix string]

    The config file has one tab-separated line per sample:
    <label> <barcode1 (70x)> <barcode2 (50x)>. For every sample two gzipped
    files are created, PRE + label + SUF + '.end1.fastq.gz' / '.end2.fastq.gz',
    plus an 'Undetermined' pair for reads whose barcodes cannot be assigned.

    Each input line is tab-separated: ID, seq1, scores1, seq2, scores2. The
    two index reads are taken from the last ':'-separated field of the last
    space-separated token of the ID, split on '+', and truncated to the
    barcode lengths of the last config line (all barcodes are assumed to
    share those lengths). A barcode that is not an exact match is corrected
    when exactly one known barcode lies within maxLevenshtein edits.

    An input file name of '-' reads from standard input; '.bz2', '.gz' and
    '.bgz' files are decompressed through an external command.

    Exits with status 1 on missing arguments, conflicting options, or a
    config file without barcodes.
    """
    # Three positional arguments are mandatory (sys.argv[3] is read below).
    if len(sys.argv) < 4:
        print('usage: python %s <inputfilename> config maxLevenshtein [-70xOnly] [-50xOnly] [-suffix string] [-prefix string]' % sys.argv[0])
        print('\tConfig_format: <label> <barcode1 (70x)> <barcode2 (50x)>')
        print('\tthe script assumes barcodes are in the read ID as follows:')
        print('\t\t@M00653:10:000000000-C2CK2:1:1102:11421:1321 1:N:0:NCCACTGTATCTCG+NGAGTAGA')
        print('\tthe script will only use the first K bases of the index reads where K is the size of the barcode specified in the config file')
        print('\tthe script assumes the input is the output of PEFastqToTabDelimited.py')
        sys.exit(1)

    inputfilename = sys.argv[1]
    config = sys.argv[2]
    maxLev = int(sys.argv[3])

    PRE = _option_value(sys.argv, '-prefix')
    SUF = _option_value(sys.argv, '-suffix')

    print('SUF: ' + SUF)
    print('PRE: ' + PRE)

    do70xOnly = '-70xOnly' in sys.argv
    if do70xOnly:
        print('will only use 70x barcodes')

    do50xOnly = '-50xOnly' in sys.argv
    if do50xOnly:
        print('will only use 50x barcodes')

    if do50xOnly and do70xOnly:
        print('conflicting options selected, do70xOnly and do50xOnly, exiting')
        sys.exit(1)

    BC1s = {}      # known 70x barcodes
    BC2s = {}      # known 50x barcodes
    BC1Dict = {}   # (BC1, BC2) or 'undetermined' -> end1 output handle
    BC2Dict = {}   # (BC1, BC2) or 'undetermined' -> end2 output handle

    try:
        with open(config) as linelist:
            for line in linelist:
                if line.strip() == '':
                    continue
                fields = line.strip().split('\t')
                label = fields[0]
                BC1 = fields[1]
                BC2 = fields[2]
                # In single-index mode the ignored barcode is replaced by a
                # run of N's so that every read matches it exactly.
                if do70xOnly:
                    BC2 = len(BC1)*'N'
                if do50xOnly:
                    BC1 = len(BC2)*'N'
                BC1s[BC1] = 1
                BC2s[BC2] = 1
                BC1Dict[(BC1, BC2)] = gzip.open(PRE + label + SUF + '.end1.fastq.gz', 'w')
                BC2Dict[(BC1, BC2)] = gzip.open(PRE + label + SUF + '.end2.fastq.gz', 'w')

        if not BC1Dict:
            print('no barcodes found in config file, exiting')
            sys.exit(1)

        BC1Dict['undetermined'] = gzip.open(PRE + 'Undetermined' + SUF + '.end1.fastq.gz', 'w')
        BC2Dict['undetermined'] = gzip.open(PRE + 'Undetermined' + SUF + '.end2.fastq.gz', 'w')

        # Barcode lengths are taken from the last config line.
        K1 = len(BC1)
        K2 = len(BC2)

        minBetweenBarcodesLev1 = _min_pairwise_distance(list(BC1s), K1)
        minBetweenBarcodesLev2 = _min_pairwise_distance(list(BC2s), K2)

        if inputfilename == '-':
            stream = sys.stdin
        else:
            # NOTE(review): the file name is interpolated into a shell command
            # unquoted, so names with spaces or shell metacharacters will break.
            if inputfilename.endswith('.bz2'):
                cmd = 'bzip2 -cd ' + inputfilename
            elif inputfilename.endswith('.gz') or inputfilename.endswith('.bgz'):
                cmd = 'zcat ' + inputfilename
            else:
                cmd = 'cat ' + inputfilename
            stream = os.popen(cmd, "r")

        try:
            k = 0
            for line in iter(stream.readline, ''):
                k += 1
                if k % 100000 == 0:
                    print(str(k/1000000.) + 'M reads processed')
                fields = line.strip().split('\t')
                ID = fields[0]
                index_reads = ID.split(' ')[-1].split(':')[-1].split('+')
                readBC1 = index_reads[0][0:K1]
                readBC2 = index_reads[1][0:K2]
                if do70xOnly:
                    readBC2 = len(readBC1)*'N'
                if do50xOnly:
                    readBC1 = len(readBC2)*'N'
                seq1 = fields[1]
                scores1 = fields[2]
                seq2 = fields[3]
                scores2 = fields[4]

                # A read goes to a sample only if both barcodes resolve and
                # the resulting pair is one listed in the config file.
                key = 'undetermined'
                matched1 = _resolve_barcode(readBC1, BC1s, minBetweenBarcodesLev1, maxLev)
                if matched1 is not None:
                    matched2 = _resolve_barcode(readBC2, BC2s, minBetweenBarcodesLev2, maxLev)
                    if matched2 is not None and (matched1, matched2) in BC2Dict:
                        key = (matched1, matched2)

                _write_pair(BC1Dict[key], BC2Dict[key], ID, seq1, scores1, seq2, scores2)
        finally:
            if stream is not sys.stdin:
                stream.close()
    finally:
        # Close every output, including on error, so gzip trailers are written.
        for handle in BC1Dict.values():
            handle.close()
        for handle in BC2Dict.values():
            handle.close()

# Run the demultiplexer only when executed as a script, so that importing
# this module does not start processing sys.argv.
if __name__ == '__main__':
    run()

