##################################
#                                #
# Last modified 05/12/2012       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import gc
import string
from sets import Set

# Base -> complementary base, preserving case; N/n (unknown base) maps to itself.
# Built once at module level because getReverseComplement is called per k-mer.
_COMPLEMENT = {'A': 'T', 'T': 'A', 'G': 'C', 'C': 'G', 'N': 'N',
               'a': 't', 't': 'a', 'g': 'c', 'c': 'g', 'n': 'n'}


def getReverseComplement(preliminarysequence):
    """Return the reverse complement of a DNA sequence.

    Each base is complemented with its case preserved ('a' -> 't'); 'N'/'n'
    map to themselves. Any other character raises KeyError.
    """
    # join() builds the result in one pass instead of quadratic string +=.
    return ''.join([_COMPLEMENT[base] for base in reversed(preliminarysequence)])

# Upper bound on sys.getsizeof() of the k-mer table before the script gives up.
# NOTE: sys.getsizeof is shallow -- it measures the dict's own hash table, not
# the k-mer strings or count objects it references -- so this is a loose guard.
_MAX_KMER_TABLE_BYTES = 250000000000

# Progress is reported every this many reads (or read pairs).
_PROGRESS_INTERVAL = 1000000


def _read_file_list(path):
    """Parse the tab-separated input list into lists of fields.

    Each entry is [filename, type, outfile_prefix, ...]. Blank lines are
    skipped. Exits with status 1 on a line with fewer than three fields, so a
    malformed list is reported before the long counting pass rather than after.
    """
    entries = []
    with open(path) as handle:
        for line in handle:
            stripped = line.strip()
            if not stripped:
                continue
            fields = stripped.split('\t')
            if len(fields) < 3:
                print('malformed line in %s (expected <filename> <tab> <type> <tab> <outfile_prefix>): %r'
                      % (path, stripped))
                sys.exit(1)
            entries.append(fields)
    return entries


def _fastq_records(handle):
    """Yield (header, sequence, quality) tuples from a 4-line-per-record FASTQ stream.

    All yielded lines are stripped of surrounding whitespace. A line found where
    a header is expected but not starting with '@' is skipped, so stray lines
    (e.g. blank lines) are stepped over. The '+' separator line is discarded.
    A trailing, incomplete record is dropped.
    """
    lines = iter(handle)
    for header in lines:
        if not header.startswith('@'):
            continue
        try:
            sequence = next(lines).strip()
            next(lines)  # '+' separator line
            quality = next(lines).strip()
        except StopIteration:
            return
        yield header.strip(), sequence, quality


def _kmers(sequence, k):
    """Yield every length-k window of sequence, including the last one.

    A sequence of length L yields L - k + 1 k-mers (none when L < k).
    """
    for start in range(len(sequence) - k + 1):
        yield sequence[start:start + k]


def _count_kmers_in_file(path, k, counts):
    """Add the k-mers of every read in the FASTQ file at path to counts, in place.

    A k-mer and its reverse complement share a single entry: whichever
    orientation is seen first becomes the key, and later occurrences of either
    orientation increment it.
    """
    reads = 0
    with open(path) as handle:
        for _header, sequence, _quality in _fastq_records(handle):
            reads += 1
            if reads % _PROGRESS_INTERVAL == 0:
                print('%.1fM reads processed from %s' % (reads / 1e6, path))
                print('memory footprint: %d' % sys.getsizeof(counts))
            for kmer in _kmers(sequence, k):
                if kmer in counts:
                    counts[kmer] += 1
                    continue
                reverse = getReverseComplement(kmer)
                if reverse in counts:
                    counts[reverse] += 1
                else:
                    counts[kmer] = 1


def _filter_kmer_counts(counts, min_count, max_count):
    """Delete from counts every k-mer whose count lies outside [min_count, max_count].

    Returns a (number retained, number removed) tuple.
    """
    removed = 0
    # Snapshot the keys: deleting from a dict while iterating it directly is an
    # error on Python 3.
    for kmer in list(counts):
        count = counts[kmer]
        if count < min_count or count > max_count:
            del counts[kmer]
            removed += 1
    return len(counts), removed


def _has_retained_kmer(sequence, k, kmers):
    """Return True if any k-mer of sequence, in either orientation, is a key of kmers."""
    for kmer in _kmers(sequence, k):
        if kmer in kmers or getReverseComplement(kmer) in kmers:
            return True
    return False


def _write_fastq_record(out, record):
    """Write a (header, sequence, quality) record as four FASTQ lines with a bare '+' line."""
    header, sequence, quality = record
    out.write('%s\n%s\n+\n%s\n' % (header, sequence, quality))


def _filter_paired_file(path, outprefix, k, kmers):
    """Filter an interleaved paired FASTQ file (each read directly followed by its mate).

    Pairs where both reads contain a retained k-mer go to <prefix>.paired.fastq.
    Pairs where only one read does have that read written to
    <prefix>.unpaired.fastq. Pairs where neither does are dropped.
    """
    base = outprefix.split('.fastq')[0]
    paired_name = base + '.paired.fastq'
    unpaired_name = base + '.unpaired.fastq'
    print('%s %s' % (paired_name, unpaired_name))
    pairs = discarded_pairs = discarded_reads = 0
    with open(path) as handle, \
            open(paired_name, 'w') as out_paired, \
            open(unpaired_name, 'w') as out_unpaired:
        records = _fastq_records(handle)
        for first in records:
            second = next(records, None)
            if second is None:
                print('warning: odd number of reads in %s; last read ignored' % path)
                break
            pairs += 1
            if pairs % _PROGRESS_INTERVAL == 0:
                print('%.1fM reads outputted from %s' % (2 * pairs / 1e6, path))
            keep_first = _has_retained_kmer(first[1], k, kmers)
            keep_second = _has_retained_kmer(second[1], k, kmers)
            if keep_first and keep_second:
                _write_fastq_record(out_paired, first)
                _write_fastq_record(out_paired, second)
            elif keep_first:
                discarded_reads += 1
                _write_fastq_record(out_unpaired, first)
            elif keep_second:
                discarded_reads += 1
                _write_fastq_record(out_unpaired, second)
            else:
                discarded_pairs += 1
    print('discarded reads %d out of %d reads' % (discarded_reads, 2 * pairs))
    print('discarded pairs %d out of %d pairs' % (discarded_pairs, pairs))


def _filter_unpaired_file(path, outprefix, k, kmers):
    """Copy reads containing at least one retained k-mer to <prefix>.unpaired.fastq."""
    out_name = outprefix.split('.fastq')[0] + '.unpaired.fastq'
    reads = discarded = 0
    with open(path) as handle, open(out_name, 'w') as out:
        for record in _fastq_records(handle):
            reads += 1
            if reads % _PROGRESS_INTERVAL == 0:
                print('%.1fM reads outputted from %s' % (reads / 1e6, path))
            if _has_retained_kmer(record[1], k, kmers):
                _write_fastq_record(out, record)
            else:
                discarded += 1
    print('discarded %d out of %d reads' % (discarded, reads))


def run():
    """Command-line entry point: filter FASTQ reads by k-mer abundance.

    Usage: list_of_input_files kmer_size minimum_abundance maximum_abundance

    Pass 1 counts every k-mer (strand-merged with its reverse complement)
    across all listed files. K-mers whose count is outside
    [minimum_abundance, maximum_abundance] are then removed. Pass 2 keeps every
    read that still contains at least one retained k-mer, writing the output
    files described in the usage text. Exits with status 1 on bad arguments,
    a malformed list file, or when the k-mer table exceeds the size guard.
    """
    if len(sys.argv) < 5:
        print('usage: python %s list_of_input_files kmer_size minimum_abundance maximum_abundance' % sys.argv[0])
        print('     list_of_input_files format: <filename> <tab> <type ("paired" or "unpaired"; paired files must be interleaved, velvet-style, each read directly followed by its mate)> <tab> <outfile_prefix>')
        print('     the script counts every k-mer (a k-mer and its reverse complement are counted together) across all input files, then discards every read none of whose k-mers has a count within [minimum_abundance, maximum_abundance]')
        print('     for paired input, reads whose mate was discarded are written to a file ending in ".unpaired.fastq", while reads whose mate was kept remain in ".paired.fastq"')
        sys.exit(1)

    list_path = sys.argv[1]
    k = int(sys.argv[2])
    min_count = int(sys.argv[3])
    max_count = int(sys.argv[4])
    if k < 1:
        print('kmer_size must be a positive integer')
        sys.exit(1)

    entries = _read_file_list(list_path)

    # Pass 1: count k-mers over all input files.
    counts = {}
    for fields in entries:
        print(fields)
        if sys.getsizeof(counts) > _MAX_KMER_TABLE_BYTES:
            print('too much memory taken, exiting')
            sys.exit(1)
        _count_kmers_in_file(fields[0], k, counts)

    print('finished inputting k-mers, size: %d' % sys.getsizeof(counts))
    print('distinct k-mers counted: %d' % len(counts))

    retained, removed = _filter_kmer_counts(counts, min_count, max_count)
    gc.collect()

    print('finished removing k-mers outside the abundance range from hash table %d' % sys.getsizeof(counts))
    print('kmers to be retained: %d' % retained)
    print('kmers to be filtered: %d' % removed)

    # Pass 2: keep reads that contain at least one retained k-mer.
    for fields in entries:
        path, read_type, outprefix = fields[0], fields[1], fields[2]
        print(fields)
        if read_type == 'paired':
            _filter_paired_file(path, outprefix, k, counts)
        elif read_type == 'unpaired':
            _filter_unpaired_file(path, outprefix, k, counts)
        else:
            print('unknown read type %r for %s (expected "paired" or "unpaired"); skipping' % (read_type, path))

# Only run when executed as a script, so the helpers can be imported without
# kicking off a full k-mer filtering job.
if __name__ == '__main__':
    run()

