##################################
#                                #
# Last modified 12/24/2010       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
import pysam

def _count_read_positions(path, fmt):
    """Count reads per start position in a SAM or BED read file.

    Args:
        path: path to the tab-separated read file.
        fmt: 'SAM' or 'bed'. For SAM the chromosome is column 3 (RNAME) and
            the 1-based POS in column 4 is converted to 0-based; for bed the
            chromosome and 0-based start are columns 1 and 2.

    Returns:
        dict mapping chromosome -> {0-based start position: read count}.
    """
    counts = {}
    with open(path) as reads:
        for line_no, line in enumerate(reads, 1):
            if line_no % 1000000 == 0:
                print('%d lines processed' % line_no)
            line = line.rstrip('\r\n')
            # Skip blank lines, '#' comments and SAM '@' header lines, which
            # would otherwise fail the column lookups below.
            if not line or line[0] == '#' or (fmt == 'SAM' and line[0] == '@'):
                continue
            fields = line.split('\t')
            if fmt == 'SAM':
                chrom = fields[2]
                if chrom == '*':
                    # Unmapped read: it has no reference position to count.
                    continue
                # SAM POS is 1-based; convert so it is comparable with the
                # 0-based, half-open region coordinates used below.
                pos = int(fields[3]) - 1
            else:
                chrom = fields[0]
                pos = int(fields[1])
            per_chrom = counts.setdefault(chrom, {})
            per_chrom[pos] = per_chrom.get(pos, 0) + 1
    return counts


def _counts_in_regions(path, chr_field, counts):
    """Return the read counts of distinct read positions covered by regions.

    Args:
        path: tab-separated region file. The chromosome is in column
            ``chr_field`` (0-based) and the start/end coordinates are in the
            two columns after it, interpreted as half-open [start, end).
        chr_field: 0-based index of the chromosome column.
        counts: mapping returned by ``_count_read_positions``.

    Returns:
        list with one read count per distinct (chromosome, position) that has
        reads and lies inside at least one region; overlapping regions do not
        count a position twice.
    """
    covered = set()
    with open(path) as regions:
        for line in regions:
            line = line.rstrip('\r\n')
            if not line or line[0] == '#':
                continue
            fields = line.split('\t')
            chrom = fields[chr_field]
            per_chrom = counts.get(chrom)
            if not per_chrom:
                # Regions on chromosomes without reads contribute nothing.
                continue
            left = int(fields[chr_field + 1])
            right = int(fields[chr_field + 2])
            for pos in range(left, right):
                if pos in per_chrom:
                    covered.add((chrom, pos))
    return [counts[chrom][pos] for chrom, pos in covered]


def _tally(position_counts):
    """Summarize an iterable of per-position read counts.

    Returns:
        (positions, reads, duplicated positions, duplicated reads), where a
        position is duplicated when it holds more than one read.
    """
    positions = reads = dup_positions = dup_reads = 0
    for n in position_counts:
        positions += 1
        reads += n
        if n > 1:
            dup_positions += 1
            dup_reads += n
    return positions, reads, dup_positions, dup_reads


def run():
    """Report read start-position duplication genome-wide and within regions.

    Command-line arguments (sys.argv[1:]):
        reads file (SAM or bed), region file, 0-based chromosome column of
        the region file, output filename, and the reads format ('SAM' or
        'bed').

    Writes a tab-separated table with the number of positions, reads,
    duplicated positions and duplicated reads, both in total and inside the
    regions. Exits with status 1 on missing arguments or an unknown format.
    """
    # Five arguments are read below (sys.argv[1]..sys.argv[5]).
    if len(sys.argv) < 6:
        print('usage: python %s SAM/bed regions chrFieldID outputfilename bed | SAM' % sys.argv[0])
        print('Note: both multi and unique reads counted right now, '
              'so the SAM file should have the multi reads filtered out')
        sys.exit(1)

    reads_path = sys.argv[1]
    regions_path = sys.argv[2]
    chr_field = int(sys.argv[3])
    out_path = sys.argv[4]
    fmt = sys.argv[5]

    if fmt not in ('SAM', 'bed'):
        print('input file incorrectly specified')
        sys.exit(1)

    counts = _count_read_positions(reads_path, fmt)
    total = _tally(n for per_chrom in counts.values() for n in per_chrom.values())
    in_regions = _tally(_counts_in_regions(regions_path, chr_field, counts))

    labels = ('Positions', 'Reads', 'Duplicated Positions', 'Duplicated Reads')
    with open(out_path, 'w') as outfile:
        outfile.write('#\tTotal\tIn Regions\n')
        for label, whole, inside in zip(labels, total, in_regions):
            outfile.write('%s\t%d\t%d\n' % (label, whole, inside))
   
# Only run when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    run()
