##################################
#                                #
# Last modified 05/01/2012       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
from sets import Set

def _read_id(header, suffix):
    """Return the pairing key for a FASTQ header line.

    The line ending ('\\n' or '\\r\\n') is removed, then the mate suffix
    (e.g. '/1') is removed if the header ends with it. Headers without the
    suffix are returned with only the line ending stripped.
    """
    header = header.rstrip('\r\n')
    if header.endswith(suffix):
        header = header[:-len(suffix)]
    return header


def _collect_ids(path, suffix, label):
    """Scan the FASTQ file at *path* and collect its read IDs.

    Every 4th line starting with line 1 is treated as a header; if such a
    line does not start with '@' the script prints 'fastq broken' and exits
    with status 1.

    Returns a tuple (set of read IDs, number of header lines seen).
    """
    ids = set()
    reads = 0
    with open(path) as stream:
        for lineno, line in enumerate(stream, 1):
            if lineno % 20000000 == 0:
                # Four lines per record, so lineno/4 reads have been seen.
                print(str(lineno // 4000000) + 'M reads processed in ' + label)
            if lineno % 4 != 1:
                continue
            if not line.startswith('@'):
                print('fastq broken')
                sys.exit(1)
            ids.add(_read_id(line, suffix))
            reads += 1
    return ids, reads


def _split_reads(path, suffix, label, common_ids, paired_out, unpaired_out):
    """Copy each 4-line record of *path* to one of two open output files.

    A record goes to *paired_out* when its read ID is in *common_ids* and
    to *unpaired_out* otherwise. Lines are written unmodified.
    """
    target = unpaired_out
    with open(path) as stream:
        for lineno, line in enumerate(stream, 1):
            if lineno % 20000000 == 0:
                print(str(lineno // 4000000) + 'M reads outputted from ' + label)
            if lineno % 4 == 1:
                # Header line: decide the destination for the whole record.
                if _read_id(line, suffix) in common_ids:
                    target = paired_out
                else:
                    target = unpaired_out
            target.write(line)


def run():
    """Split two FASTQ mate files into paired and unpaired reads.

    Command line: ``end1 end2 out_prefix``. Read IDs are expected to end in
    '/1' (end1) and '/2' (end2). Reads whose ID occurs in both inputs are
    written to ``<out_prefix>.end1.fastq`` and ``<out_prefix>.end2.fastq``;
    all other reads from either input go to ``<out_prefix>.unpaired.fastq``.

    Prints the usage text and exits with status 1 when fewer than three
    arguments are given, and prints 'fastq broken' and exits with status 1
    when a header line does not start with '@'.
    """
    # Three positional arguments are required, i.e. len(sys.argv) >= 4.
    if len(sys.argv) < 4:
        print('usage: python %s end1 end2 out_prefix' % sys.argv[0])
        print('   Note: this script relies on read IDs ending on /1 and /2')
        sys.exit(1)

    end1 = sys.argv[1]
    end2 = sys.argv[2]
    prefix = sys.argv[3]

    # Outputs are opened before the scan so an unwritable prefix fails early.
    with open(prefix + '.end1.fastq', 'w') as outend1:
        with open(prefix + '.end2.fastq', 'w') as outend2:
            with open(prefix + '.unpaired.fastq', 'w') as outunpaired:
                end1_ids, end1_reads = _collect_ids(end1, '/1', 'end1')
                end2_ids, end2_reads = _collect_ids(end2, '/2', 'end2')

                common_ids = end1_ids & end2_ids
                # Only the intersection is needed from here on.
                del end1_ids, end2_ids

                print('%d %d %d' % (end1_reads, end2_reads, len(common_ids)))

                _split_reads(end1, '/1', 'end1', common_ids,
                             outend1, outunpaired)
                _split_reads(end2, '/2', 'end2', common_ids,
                             outend2, outunpaired)

# Run the pipeline only when executed as a script, so that importing this
# module does not parse sys.argv or touch any files.
if __name__ == '__main__':
    run()

