##################################
#                                #
# Last modified 2018/09/25       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import gzip

def getReverseComplement(preliminarysequence):
    """Return the reverse complement of a DNA sequence.

    Each base is replaced by its complement (A<->T, G<->C, N->N) and the
    order of the bases is reversed.  Case is preserved: upper-case input
    bases give upper-case output and lower-case give lower-case.

    Raises KeyError if the sequence contains any character other than
    A, C, G, T, N (in either case).
    """
    DNA = {'A':'T','T':'A','G':'C','C':'G','N':'N','a':'t','t':'a','g':'c','c':'g','n':'n'}
    # Build the result with a single join instead of repeated string
    # concatenation, which is quadratic in the sequence length.
    return ''.join([DNA[base] for base in reversed(preliminarysequence)])

def _openText(path, allowStdin=False):
    """Open ``path`` for line-by-line text reading.

    ``-`` maps to standard input when ``allowStdin`` is true, and paths
    ending in ``.gz`` are decompressed on the fly.
    """
    if allowStdin and path == '-':
        return sys.stdin
    if path.endswith('.gz'):
        if sys.version_info[0] >= 3:
            # Python 3's gzip.open yields bytes unless text mode is requested.
            return gzip.open(path, 'rt')
        return gzip.open(path)
    return open(path)


def _loadGuides(inputfilename, seqFieldID, labelFieldID, doMFB):
    """Read the sgRNA table and index labels by reverse-complemented sequence.

    Returns ``(SeqDict, minLen, maxLen)``.  ``SeqDict`` maps each reverse
    complemented (upper-cased) guide sequence to its label; labels of guides
    that share a sequence are joined with ``,,``.  ``minLen``/``maxLen`` are
    the shortest and longest indexed sequence lengths (``inf`` and ``0`` when
    the table contains no guides).
    """
    SeqDict = {}
    minLen = float('inf')
    maxLen = 0
    guideFile = _openText(inputfilename)
    try:
        for line in guideFile:
            if line.startswith('#'):
                continue
            if not line.strip():
                # Blank lines (e.g. a trailing newline) carry no guide.
                continue
            fields = line.strip().split('\t')
            seq = fields[seqFieldID].upper()
            if doMFB:
                seq = seq[1:]
            seq = getReverseComplement(seq)
            label = fields[labelFieldID]
            if seq in SeqDict:
                SeqDict[seq] += ',,' + label
            else:
                SeqDict[seq] = label
            minLen = min(minLen, len(seq))
            maxLen = max(maxLen, len(seq))
    finally:
        guideFile.close()
    return SeqDict, minLen, maxLen


def _buildMismatchIndex(SeqDict):
    """Map every single-substitution variant of each guide to its label.

    When a variant is reachable from more than one guide, the first guide
    encountered while iterating ``SeqDict`` keeps it.
    """
    SeqDict1MM = {}
    for seq in SeqDict:
        label = SeqDict[seq]
        for i in range(len(seq)):
            # All four bases plus N; leaving out 'T' would make reads with a
            # T substitution impossible to rescue.
            for B in 'ACGTN':
                newSeq = seq[:i] + B + seq[i+1:]
                if newSeq not in SeqDict1MM:
                    SeqDict1MM[newSeq] = label
    return SeqDict1MM


def _findLabel(read, lookup, minLen, maxLen):
    """Return the label of the longest prefix of ``read`` found in ``lookup``.

    Prefix lengths are tried from ``min(maxLen, len(read))`` down to
    ``minLen``.  Returns None when no prefix is present.
    """
    K = min(maxLen, len(read))
    while K >= minLen:
        prefix = read[:K]
        if prefix in lookup:
            return lookup[prefix]
        K -= 1
    return None


def _countReads(fastq, SeqDict, SeqDict1MM, CountDict, minLen, maxLen):
    """Count guide matches over all reads of a FASTQ file.

    Each read is first matched exactly against ``SeqDict`` and, failing
    that, against ``SeqDict1MM``.  ``CountDict`` is incremented in place;
    the number of matched reads is returned.
    """
    TotalCounts = 0
    readFile = _openText(fastq, allowStdin=True)
    try:
        # Position within the 4-line FASTQ record:
        # 1 = header, 2 = sequence, 3 = '+' separator, 4 = qualities.
        state = 1
        L = 0
        for line in readFile:
            L += 1
            if L % 4000000 == 0:
                print(str(L // 4000000) + 'M reads processeed')
            if state == 1:
                # Lines before a header that do not start with '@' are skipped.
                if line.startswith('@'):
                    state = 2
                continue
            if state == 2:
                seq = line.strip()
                label = _findLabel(seq, SeqDict, minLen, maxLen)
                if label is None:
                    label = _findLabel(seq, SeqDict1MM, minLen, maxLen)
                if label is not None:
                    CountDict[label] += 1
                    TotalCounts += 1
                state = 3
                continue
            if state == 3:
                state = 4
                continue
            # state == 4: the quality line is consumed unconditionally.  It
            # may legitimately begin with '@', so it must never be tested as
            # a header.
            state = 1
    finally:
        if readFile is not sys.stdin:
            readFile.close()
    return TotalCounts


def _writeCounts(outfilename, CountDict, TotalCounts):
    """Write the per-label counts and reads-per-million table, sorted by label."""
    RPMNormFactor = TotalCounts/(1000000 + 0.0)
    with open(outfilename, 'w') as outfile:
        outfile.write('#sgRNA\tcounts\tRPM' + '\n')
        for label in sorted(CountDict):
            counts = CountDict[label]
            # With no matched reads the normalisation factor is zero; report
            # an RPM of 0.0 instead of dividing by zero.
            if TotalCounts:
                RPM = counts/RPMNormFactor
            else:
                RPM = 0.0
            outfile.write(label + '\t' + str(counts) + '\t' + str(RPM) + '\n')


def run():
    """Count sgRNA occurrences in a FASTQ file and write a counts/RPM table.

    Command line (read from sys.argv):
        fastq sgRNAs_file sequenceFieldID labelFieldID outfilename [-missingFirstBase]

    The sgRNA sequences are reverse complemented before matching.  A read is
    assigned to the guide matching the longest prefix of the read, exactly
    if possible and otherwise allowing one substitution.  With
    -missingFirstBase the first base of every sgRNA is dropped before it is
    reverse complemented.

    Prints the usage text and exits with status 1 when fewer than the five
    required arguments are given.
    """
    # Five positional arguments plus the program name: sys.argv[5] is read below.
    if len(sys.argv) < 6:
        print('usage: python %s fastq sgRNAs_file sequenceFieldID labelFieldID outfilename [-missingFirstBase]' % sys.argv[0])
        print('\tthe script accepts stdin and .gz as input for fastq and .gz for the sgRNAs_file')
        print('\tNote!!!: it is assumed that the sequencing data corresponds to the reverse complement of the sgRNAs')
        sys.exit(1)

    fastq = sys.argv[1]
    inputfilename = sys.argv[2]
    seqFieldID = int(sys.argv[3])
    labelFieldID = int(sys.argv[4])
    outfilename = sys.argv[5]

    doMFB = '-missingFirstBase' in sys.argv

    SeqDict, minLen, maxLen = _loadGuides(inputfilename, seqFieldID, labelFieldID, doMFB)
    # Every label starts at zero so guides without reads still appear in the output.
    CountDict = dict.fromkeys(SeqDict.values(), 0)
    SeqDict1MM = _buildMismatchIndex(SeqDict)

    print('finished inputting reference guides')

    TotalCounts = _countReads(fastq, SeqDict, SeqDict1MM, CountDict, minLen, maxLen)

    _writeCounts(outfilename, CountDict, TotalCounts)

# Only run the counting pipeline when executed as a script, so that importing
# this module (e.g. to reuse getReverseComplement) does not parse sys.argv,
# read files, or call sys.exit.
if __name__ == '__main__':
    run()

