##################################
#                                #
# Last modified 2021/09/28       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import os
from sets import Set
import Levenshtein
import numpy as np

def getReverseComplement(preliminarysequence):
    """Return the reverse complement of a DNA sequence.

    Every base is translated through a fixed table (A<->T, G<->C, N->N,
    with lower-case bases mapping to lower-case complements) and the
    translated bases are emitted in reverse order.

    Raises KeyError if the sequence contains any character other than
    A, C, G, T or N (in either case).
    """
    DNA = {'A':'T','T':'A','G':'C','C':'G','N':'N','a':'t','t':'a','g':'c','c':'g','n':'n'}
    # A single join over the reversed sequence avoids rebuilding the
    # result string once per base.
    return ''.join([DNA[base] for base in reversed(preliminarysequence)])

def _readBarcodeTable(path):
    """Parse the tab-delimited sgRNA table into a dict keyed by barcode.

    Each data line is expected to hold, in order: barcode, sgRNA, label,
    UMI count, read count. Lines starting with '#' and blank lines are
    ignored. Lines whose fifth field is the literal string 'nan' are
    reported on stdout and skipped.

    Returns a dict mapping barcode -> list of (sgRNA, label, UMI, reads)
    tuples, in file order.
    """
    BCDict = {}
    with open(path) as infile:
        for line in infile:
            if line.startswith('#'):
                continue
            stripped = line.strip()
            if not stripped:
                # Tolerate blank lines (e.g. a trailing empty line).
                continue
            fields = stripped.split('\t')
            BC = fields[0]
            sgRNA = fields[1]
            label = fields[2]
            UMI = int(fields[3])
            if fields[4] == 'nan':
                print('skipping: ' + stripped)
                continue
            reads = int(fields[4])
            BCDict.setdefault(BC, []).append((sgRNA, label, UMI, reads))
    return BCDict


def run():
    """Assign each barcode its dominant sgRNA and write the result table.

    Command-line arguments (all required, read from sys.argv):
        1. path to the tab-delimited input table
           (barcode, sgRNA, label, UMI count, read count)
        2. 'UMIs' or 'reads' -- which count column to use
        3. minCounts   -- minimum summed count for a barcode to be kept
        4. minFraction -- minimum share of the barcode's total that the
                          top sgRNA must reach
        5. path of the output file

    For every barcode (in sorted order) the chosen counts are summed; the
    barcode is dropped if the sum is below minCounts. Otherwise the sgRNA
    with the largest fraction of the sum is selected (the first one in
    file order on ties) and written out if that fraction is at least
    minFraction. Output columns: barcode, sgRNA, label, counts, total,
    fraction.

    Prints a usage/error message and exits with status 1 when arguments
    are missing or the second argument is not 'UMIs' or 'reads'.
    """
    # Five positional arguments are needed, i.e. sys.argv[1]..sys.argv[5].
    if len(sys.argv) < 6:
        print('usage: python %s SHARE-PERTURB-sgRNA-UMI_output UMIs|reads minCounts minFraction outfilename' % sys.argv[0])
        print('\t the script assumes that cell barcodes have already been annotated with SHARE-seq-barcode-annotate.py or SHARE-seq-barcode-annotate-UG.py')
        sys.exit(1)

    sgRNAs = sys.argv[1]
    UMIsOrReads = sys.argv[2]
    minCounts = int(sys.argv[3])
    minFraction = float(sys.argv[4])
    outfilename = sys.argv[5]

    # Reject an unknown mode up front; otherwise every total would stay 0
    # and the run would silently produce an empty table or crash later.
    if UMIsOrReads not in ('UMIs', 'reads'):
        print("error: second argument must be 'UMIs' or 'reads', got '%s'" % UMIsOrReads)
        sys.exit(1)

    # Position of the selected count inside each (sgRNA, label, UMI, reads) tuple.
    countIndex = 2 if UMIsOrReads == 'UMIs' else 3

    BCDict = _readBarcodeTable(sgRNAs)

    with open(outfilename, 'w') as outfile:
        outfile.write('#barcode\tsgRNA\tlabel\tcounts\ttotal\tfraction' + '\n')

        for BC in sorted(BCDict):
            entries = BCDict[BC]
            # Start from 0.0 so the total (and the fractions below) are floats.
            total = sum((entry[countIndex] for entry in entries), 0.0)
            if total < minCounts:
                continue
            if total == 0:
                # Only reachable when minCounts <= 0; no fraction is defined.
                continue

            fractions = [entry[countIndex] / total for entry in entries]
            maxF = max(fractions)
            if maxF < minFraction:
                continue

            # list.index returns the first maximum, so ties go to the
            # sgRNA that appears first in the input.
            best = entries[fractions.index(maxF)]
            sgRNA = best[0]
            label = best[1]
            outline = '\t'.join([BC, sgRNA, label, str(best[countIndex]), str(total), str(maxF)])
            outfile.write(outline + '\n')
            
# Run the command-line entry point only when this file is executed as a
# script, so that importing the module does not parse sys.argv or exit.
if __name__ == '__main__':
    run()
