##################################
#                                #
# Last modified 07/23/2013       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys

try:
    # psyco is an optional JIT accelerator for old CPython 2 builds; when it is
    # unavailable (or refuses to run on this platform) the script just runs
    # unaccelerated. Catch Exception rather than using a bare except so that
    # KeyboardInterrupt / SystemExit are not swallowed.
    import psyco
    psyco.full()
except Exception:
    pass

def _read_fastq_records(stream):
    """Yield (header, sequence, quality) tuples from a 4-line FASTQ stream.

    Line endings are removed; sequence and quality are whitespace-stripped,
    the header keeps its leading '@'. Blank lines between records are
    skipped. Iteration stops, after printing a diagnostic, at the first
    header not starting with '@', separator not starting with '+', or record
    cut short by end of input. Multi-line FASTQ records are not supported.
    """
    lines = (line.rstrip('\r\n') for line in stream)
    for header in lines:
        if not header.strip():
            continue
        if not header.startswith('@'):
            print('invalid read %s' % header)
            return
        seq = next(lines, None)
        sep = next(lines, None)
        qual = next(lines, None)
        if qual is None:
            print('truncated read %s' % header)
            return
        if not sep.startswith('+'):
            print('invalid read %s' % header)
            return
        yield header, seq.strip(), qual.strip()


def run(argv=None):
    """Split reads that contain both ends back-to-back into two FASTQ files.

    Command line (argv[0] is the program name):
        inputfilename first-end-bp bp-to-be-trimmed second-end-bp outfileprefix [-fromSRA]

    Each read is cut into end 1 (the first ``first-end-bp`` bases), a gap of
    ``bp-to-be-trimmed`` bases that is discarded, and end 2 (the following
    ``second-end-bp`` bases). End 1 is written to
    ``<outfileprefix>.end1.<first-end-bp>mers.fastq`` with '/1' appended to
    the read ID, end 2 to ``<outfileprefix>.end2.<second-end-bp>mers.fastq``
    with '/2'. Spaces in read IDs become '_'; with -fromSRA the ' length...'
    part of the ID is dropped first. '.' in sequences is rewritten as 'N'.
    An input filename of '-' reads standard input.

    Reads shorter than first-end-bp + bp-to-be-trimmed + second-end-bp are
    skipped entirely (in both outputs, so the files stay paired) and their
    number is reported at the end.

    :param argv: argument vector; defaults to ``sys.argv``.
    Calls ``sys.exit(1)`` after printing usage when too few arguments are given.
    """
    if argv is None:
        argv = sys.argv

    # Five positional arguments are required, so argv needs at least 6 entries.
    if len(argv) < 6:
        print('usage: python %s inputfilename first-end-bp bp-to-be-trimmed '
              'second-end-bp outfileprefix [-fromSRA]' % argv[0])
        print('\tnote: the -fromSRA option will remove the length part of read IDs')
        print('\tnote: use - instead of an input filename to indicate standard input')
        sys.exit(1)

    doSRA = '-fromSRA' in argv
    if doSRA:
        print('will treat reads as coming from SRA')

    inputfilename = argv[1]
    end1 = int(argv[2])
    trim = int(argv[3])
    end2 = int(argv[4])
    outputfilename1 = argv[5] + '.end1.' + str(end1) + 'mers.fastq'
    outputfilename2 = argv[5] + '.end2.' + str(end2) + 'mers.fastq'

    start2 = end1 + trim        # index of the first base of end 2
    required = start2 + end2    # minimum length for both ends to be complete

    doStdInput = inputfilename == '-'
    input_stream = sys.stdin if doStdInput else open(inputfilename)

    i = 0
    shorter = 0
    try:
        with open(outputfilename1, 'w') as outfile1, open(outputfilename2, 'w') as outfile2:
            for header, seq, qual in _read_fastq_records(input_stream):
                i += 1
                if i % 10000000 == 0:
                    print(str(i // 1000000) + 'M reads processed')
                # Decide before writing anything, so a skipped read leaves no
                # orphan header/'+' lines and the two outputs stay in sync.
                if len(seq) < required or len(qual) < required:
                    shorter += 1
                    continue
                readID = header.strip()
                if doSRA:
                    readID = readID.split(' length')[0]
                readID = readID.replace(' ', '_')
                outfile1.write('%s/1\n%s\n+\n%s\n' % (
                    readID, seq[:end1].replace('.', 'N'), qual[:end1]))
                outfile2.write('%s/2\n%s\n+\n%s\n' % (
                    readID, seq[start2:required].replace('.', 'N'), qual[start2:required]))
    finally:
        if not doStdInput:
            input_stream.close()

    if shorter > 0:
        print('%d sequences shorter than desired length, skipped' % shorter)
# Run only when executed as a script, not when imported.
if __name__ == '__main__':
    run()

