##################################
#                                #
# Last modified 2018/01/02       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import os
import subprocess
import regex

def _abort(message):
    """Write *message* to stderr and exit with status 1."""
    sys.stderr.write(message + '\n')
    sys.exit(1)


def _option_value(flag, convert=str):
    """Return the command-line value that follows *flag*, or None when *flag* is absent.

    The raw string is passed through *convert* (e.g. ``int``). Exits with
    status 1 when *flag* is the last argument and therefore has no value.
    """
    if flag not in sys.argv:
        return None
    position = sys.argv.index(flag) + 1
    if position >= len(sys.argv):
        _abort('missing value for %s' % flag)
    return convert(sys.argv[position])


def _normalize_adapter(text):
    """Return the first whitespace-delimited token of *text*, upper-cased with U replaced by T.

    Returns '' when *text* contains no token.
    """
    tokens = text.split()  # any whitespace, so a CRLF '\r' never leaks into the adapter
    if not tokens:
        return ''
    return tokens[0].upper().replace('U', 'T')


def _trim_adapter_suffix(sequence, adapter):
    """Cut *sequence* at the leftmost position whose suffix is a prefix of *adapter*.

    Only the last len(adapter) positions are examined, so this removes a
    (possibly partial) adapter running off the 3' end of the read. The
    sequence is returned unchanged when no such suffix exists.
    """
    length = len(sequence)
    for pos in range(max(0, length - len(adapter)), length):
        if adapter.startswith(sequence[pos:]):
            return sequence[:pos]
    return sequence


def _find_last_prefix(sequence, prefix, fuzzy):
    """Return the start index of the last occurrence of *prefix* in *sequence*, or -1.

    When *fuzzy* (a compiled ``regex`` pattern) is given, the text of its last
    non-empty match is located instead; otherwise *prefix* is matched exactly.
    An empty *prefix* never matches.
    """
    if fuzzy is None:
        return sequence.rfind(prefix) if prefix else -1
    # Empty fuzzy matches (possible when deletions are allowed) would make
    # rfind('') report a bogus hit at the end of the read, so drop them.
    hits = [hit for hit in fuzzy.findall(sequence) if hit]
    if not hits:
        return -1
    return sequence.rfind(hits[-1])


def _trim_sequence(sequence, adapter, prefix, fuzzy, polya):
    """Apply adapter and poly-A trimming to one (already SMART-clipped) read sequence.

    The read is cut at the last occurrence of the adapter *prefix*. If that is
    not found, or is found at position 0, the read is instead trimmed of an
    adapter prefix hanging off its 3' end. When *polya* is non-empty, the
    result is then cut at the first occurrence of *polya*.
    """
    start = _find_last_prefix(sequence, prefix, fuzzy)
    if start > 0:
        sequence = sequence[:start]
    else:
        sequence = _trim_adapter_suffix(sequence, adapter)
    if polya:
        sequence = sequence.partition(polya)[0]
    return sequence


def _select_output(sequence, quality, retain3p, min_len):
    """Apply the -retain3p / -minLen filters to a trimmed read.

    Returns the (sequence, quality) pair to write, or None to drop the read.
    With *retain3p*, reads shorter than it are dropped and only the last
    *retain3p* bases are kept.
    """
    if retain3p is not None:
        if len(sequence) < retain3p:
            return None
        sequence, quality = sequence[-retain3p:], quality[-retain3p:]
    if min_len is not None and len(sequence) < min_len:
        return None
    return sequence, quality


def _process_reads(source, out, trim, smart, retain3p, min_len, show_progress):
    """Trim each 4-line FASTQ record read from *source* and write the kept ones to *out*.

    Blank lines between records are ignored. Exits with status 1 when a
    record header does not start with '@' or its 3rd line does not start
    with '+'. Returns a dict mapping trimmed read length (before
    -retain3p/-minLen filtering) to read count.
    """
    length_counts = {}
    state = 1  # which line of the current 4-line record comes next
    read_id = sequence = None
    for line_number, line in enumerate(source, 1):
        if show_progress and line_number % 20000000 == 0:
            print('%d reads processed' % (line_number // 4))
        if state == 1:
            if not line.strip():
                continue
            if not line.startswith('@'):
                _abort('fastq file broken, exiting')
            read_id = line.strip()
            state = 2
        elif state == 2:
            sequence = trim(line.strip()[smart:])
            state = 3
        elif state == 3:
            if not line.startswith('+'):
                _abort('fastq file broken, exiting')
            state = 4
        else:
            length_counts[len(sequence)] = length_counts.get(len(sequence), 0) + 1
            # Trimming only removes 3' bases, so truncating keeps qualities aligned.
            quality = line.strip()[smart:][:len(sequence)]
            kept = _select_output(sequence, quality, retain3p, min_len)
            if kept is not None:
                out.write('%s\n%s\n+\n%s\n' % (read_id, kept[0], kept[1]))
            state = 1
    if state != 1:
        sys.stderr.write('warning: fastq file ends with an incomplete record, which was ignored\n')
    return length_counts


def run():
    """Trim 3' adapters (and optionally other bases) from the reads of a FASTQ file.

    Command line: ``fastq adapterfile outfile [options]``. ``fastq`` may be
    ``-`` to read standard input and ``outfile`` may be ``-stdout``. The
    adapter is the first token of the first line of ``adapterfile`` unless
    ``-adapter SEQ`` is given; either way it is upper-cased with U -> T.
    Per read, in order:

    * ``-SMART N``: drop the first N bases and qualities;
    * ``-findFirst N``: cut at the last occurrence of the adapter's first
      N bp, allowing up to ``-MM K`` errors (substitutions, insertions or
      deletions, via the ``regex`` module's fuzzy matching). If nothing is
      found, or the match starts at position 0, the read is instead trimmed
      of an adapter prefix hanging off its 3' end (the default behavior);
    * ``-polyA N``: cut at the first run of N A's;
    * ``-retain3p N``: keep only the last N bases (shorter reads are dropped);
    * ``-minLen N``: drop reads shorter than N after the steps above.

    Unless writing to standard output, progress messages and a histogram of
    trimmed read lengths ("length count" per line) are printed to stdout.
    Exits with status 1 on bad usage or a malformed FASTQ record.
    """
    if len(sys.argv) < 4:
        print('usage: python %s fastq  <adapter filename> outfilename [-adapter sequence] [-MM number] [-findFirst bp] [-polyA bp] [-retain3p bp] [-minLen bp] [-SMART bp]' % sys.argv[0])
        print('\tUse - instead of a fastq name if you want to capture standard input')
        print('\tUse -stdout instead of filename if you want to print to standard output')
        print('\tBy default the script expects read to end with the adapter sequence; this may not always be the case, so use the -findFirst option if you want the script to find the last occurrence of the first X bp of the adapter and trim the read there')
        print('\tUse the -MM option if you want to allow mismatches in the adapter; this option only works together with the [-findFirst] option')
        print('\tUse the -SMART option if the SMART protocol was used; this option will trim the indicated number of bases from the 5p end')
        sys.exit(1)

    fastq = sys.argv[1]
    outputfilename = sys.argv[3]
    doStdOut = outputfilename == '-stdout'

    adapter = _option_value('-adapter')
    if adapter is None:
        with open(sys.argv[2]) as adapter_file:
            adapter = adapter_file.readline()
    adapter = _normalize_adapter(adapter)
    if not adapter:
        _abort('no adapter sequence found')

    smart = _option_value('-SMART', int) or 0
    retain3p = _option_value('-retain3p', int)
    min_len = _option_value('-minLen', int)
    polya_len = _option_value('-polyA', int)
    polya = 'A' * polya_len if polya_len else ''

    prefix = ''
    fuzzy = None
    first_bp = _option_value('-findFirst', int)
    if first_bp is not None:
        prefix = adapter[:first_bp]
        mismatches = _option_value('-MM', int) or 0
        if mismatches > 0 and prefix:
            # The group makes the error budget apply to the whole prefix;
            # a bare "{e<=K}" would bind to its last character only.
            fuzzy = regex.compile('(?:%s){e<=%d}' % (prefix, mismatches))

    def trim(raw_sequence):
        return _trim_sequence(raw_sequence, adapter, prefix, fuzzy, polya)

    source = sys.stdin if fastq == '-' else open(fastq)
    try:
        out = sys.stdout if doStdOut else open(outputfilename, 'w')
        try:
            length_counts = _process_reads(source, out, trim, smart, retain3p,
                                           min_len, not doStdOut)
        finally:
            if out is not sys.stdout:
                out.close()
    finally:
        if source is not sys.stdin:
            source.close()

    if not doStdOut:
        for length in sorted(length_counts):
            print('%d %d' % (length, length_counts[length]))

# Only run when executed as a script, so the module can be imported without side effects.
if __name__ == '__main__':
    run()

