##################################
#                                #
# Last modified 12/04/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import math

# Optional JIT acceleration: psyco only exists for old 32-bit Python 2, so its
# absence (or failure to initialise) is expected and silently ignored.  Catch
# Exception rather than using a bare except so Ctrl-C / sys.exit still work.
try:
    import psyco
    psyco.full()
except Exception:
    pass

def _log(message):
    """Write one status line to stdout.

    Uses sys.stdout.write instead of the print statement so this code parses
    under both Python 2 and Python 3.
    """
    sys.stdout.write(message + '\n')


def _load_spikes(spikesfilename):
    """Return the set of spike names listed in *spikesfilename*.

    The spike name is the first tab-delimited, then space-delimited, token of
    each line; any further columns are ignored.
    """
    spikes = set()
    with open(spikesfilename) as spikefile:
        for line in spikefile:
            spikes.add(line.strip().split('\t')[0].split(' ')[0])
    return spikes


def _parse_alignment(line, spikes):
    """Return the tab-split fields of one alignment line, or None to skip it.

    A line is skipped when it has fewer than 7 columns (column 7 is read
    below), when *spikes* is not None and the reference name (column 3) is
    not in it, or when column 7 is not '0'.  Assumes bowtie's default output,
    where column 7 presumably counts other alignments of the read, so '0'
    keeps only uniquely aligned reads -- confirm against the aligner used.
    """
    fields = line.strip().split('\t')
    if len(fields) < 7:
        return None
    if spikes is not None and fields[2] not in spikes:
        return None
    if fields[6] != '0':
        return None
    return fields


def _pair_key(read_id, mate, do_casava, do_hiseq_barcode):
    """Strip the mate-specific suffix from *read_id* so both ends share a key.

    *mate* (1 or 2) only matters for the Casava 1.8 ('_<mate>:') and HiSeq
    barcode (' <mate>') naming schemes; otherwise everything from the first
    '/' onward is dropped.
    """
    if do_casava:
        return read_id.split('_%d:' % mate)[0]
    if do_hiseq_barcode:
        return read_id.split(' %d' % mate)[0]
    return read_id.split('/')[0]


def _read_end1(filenames, spikes, first_n, do_casava, do_hiseq_barcode, do_end3):
    """Map pair key -> (reference name, offset) for usable end-1 alignments.

    At most *first_n* reads are kept in total across all files (None means
    no limit); once the limit is hit the remaining files are not read.  A
    later alignment with the same key overwrites an earlier one.
    """
    positions = {}
    kept = 0
    for filename in filenames:
        with open(filename) as alignments:
            for lineno, line in enumerate(alignments):
                if lineno % 5000000 == 0:
                    _log('%dM alignments processed in %s' % (lineno // 1000000, filename))
                fields = _parse_alignment(line, spikes)
                if fields is None:
                    continue
                read_id = fields[0]
                # End-2 reads ('/2'), and with -end3 also '/3' reads, do not
                # belong in the end-1 table.
                if read_id.endswith('/2'):
                    continue
                if do_end3 and read_id.endswith('/3'):
                    continue
                offset = int(fields[3])
                if first_n is not None and kept >= first_n:
                    _log('found %d  first end reads, searching for paired reads' % first_n)
                    return positions
                key = _pair_key(read_id, 1, do_casava, do_hiseq_barcode)
                positions[key] = (fields[2], offset)
                kept += 1
                if kept % 5000000 == 0:
                    _log('%dM unique end1 reads found' % (kept // 1000000))
    return positions


def _count_inserts(filenames, positions, spikes, do_casava, do_hiseq_barcode):
    """Histogram insert lengths for end-2 alignments whose mate is in *positions*.

    Pairs whose two ends align to different references are ignored.  The
    insert length is |offset2 - offset1| + length of the end-2 read sequence
    (column 5).  math.fabs makes it a float; that is kept so the output
    format (e.g. '210.0') stays the same.
    """
    distribution = {}
    for filename in filenames:
        with open(filename) as alignments:
            for lineno, line in enumerate(alignments):
                if lineno % 5000000 == 0:
                    _log('%dM alignments processed in %s' % (lineno // 1000000, filename))
                fields = _parse_alignment(line, spikes)
                if fields is None:
                    continue
                read_id = fields[0]
                if read_id.endswith('/1'):
                    continue
                mate1 = positions.get(_pair_key(read_id, 2, do_casava, do_hiseq_barcode))
                if mate1 is None or mate1[0] != fields[2]:
                    continue
                insert_length = math.fabs(int(fields[3]) - mate1[1]) + len(fields[4])
                distribution[insert_length] = distribution.get(insert_length, 0) + 1
    return distribution


def run():
    """Build an insert-length histogram from paired bowtie alignment files.

    Command line (read from sys.argv):
      argv[1]  comma-separated end-1 alignment files
      argv[2]  comma-separated end-2 alignment files
      argv[3]  output file: a '#length<TAB>number' header followed by one
               'length<TAB>count' line per insert length, sorted by length
      -first N          keep at most N end-1 reads in total (over all files)
      -fromspikes FILE  only use alignments to references listed in FILE
      -casava1.8 / -HiSeqBarcode  read-name schemes used to pair the ends
      -end3             also skip end-1 reads whose name ends in '/3'
    Prints usage and exits with status 1 when fewer than three positional
    arguments are given.
    """
    argv = sys.argv
    if len(argv) < 4:
        _log('usage: python %s <input bowtie_end1_filename1[,filename2,...,filenameN]> <input bowtie_end2_filename1[,filename2,...,filenameN]> <outfilename> [-first <number read pairs>] [-fromspikes <spike filename (any format where each line starts with the spike name)>] [-casava1.8] [-HiSeqBarcode] [-end3]' % argv[0])
        sys.exit(1)

    end1_filenames = argv[1].split(',')
    end2_filenames = argv[2].split(',')
    outputfilename = argv[3]

    do_casava = '-casava1.8' in argv
    do_hiseq_barcode = '-HiSeqBarcode' in argv
    # The usage text advertises '-end3'; the bare 'end3' form is still
    # accepted for scripts written against the old (mismatched) check.
    do_end3 = '-end3' in argv or 'end3' in argv

    _log('output to  %s' % outputfilename)

    first_n = None
    if '-first' in argv:
        first_n = int(argv[argv.index('-first') + 1])
        _log('will look at the first %d read pairs' % first_n)

    spikes = None
    if '-fromspikes' in argv:
        spikesfilename = argv[argv.index('-fromspikes') + 1]
        _log('will estimate distribution from spikes')
        spikes = _load_spikes(spikesfilename)

    positions = _read_end1(end1_filenames, spikes, first_n,
                           do_casava, do_hiseq_barcode, do_end3)
    distribution = _count_inserts(end2_filenames, positions, spikes,
                                  do_casava, do_hiseq_barcode)

    # Open the output only once all input has been read, so a failure while
    # reading does not leave a truncated, header-only result behind.
    with open(outputfilename, 'w') as outfile:
        outfile.write('#length\tnumber\n')
        for length in sorted(distribution):
            outfile.write('%s\t%s\n' % (length, distribution[length]))

# Run only when executed as a script, so the module can be imported (e.g. for
# testing) without parsing sys.argv and touching files.
if __name__ == '__main__':
    run()

