##################################
#                                #
# Last modified 06/30/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import math
import string
from sets import Set


def _load_regions(regioncalls):
    """Read the tab-separated region file and index its regions.

    Lines starting with '#' and lines whose first column is the literal
    'chr' (a header row) are ignored.  Columns used (0-based): 0 chromosome,
    1 start, 2 stop, 6 score (called PValue in this script), 9 peak offset
    relative to start.

    Returns a pair (regionListSorted, regionDict):
      regionListSorted -- dict mapping each region tuple
                          (PValue, chrom, start, stop, peak) to an empty list
                          that will later receive motif distances.
      regionDict       -- dict mapping chromosome name to the list of region
                          tuples on that chromosome, in file order.
    """
    regionListSorted = {}
    regionDict = {}
    with open(regioncalls) as linelist:
        for line in linelist:
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            if fields[0] == 'chr':
                continue
            chrom = fields[0]
            try:
                start = int(fields[1])
                stop = int(fields[2])
            except (ValueError, IndexError):
                # Unparseable coordinates: report the line and really skip it
                # (previously execution fell through with stale values).
                print('skipping %s' % line)
                continue
            PValue = float(fields[6])
            peak = start + int(fields[9])
            region = (PValue, chrom, start, stop, peak)
            regionListSorted[region] = []
            regionDict.setdefault(chrom, []).append(region)
    return regionListSorted, regionDict


def _collect_motif_distances(getallsites, motif_prefix, radius,
                             regionDict, regionListSorted):
    """Assign each matching motif site to the first region whose peak is near.

    The sites file is space-separated: column 0 is the motif name, column 1
    the chromosome, columns 2 and 3 the left/right motif coordinates.  Only
    lines whose motif name starts with motif_prefix are considered.  For each
    such site the midpoint is compared against the peaks of the regions on
    the same chromosome; the signed distance (midpoint - peak) is appended to
    the first region with |distance| < radius.  regionListSorted is updated
    in place.
    """
    with open(getallsites) as linelist:
        for i, line in enumerate(linelist, 1):
            if i % 1000000 == 0:
                print('%d lines processed' % i)
            fields = line.strip().split(' ')
            if not fields[0].startswith(motif_prefix):
                continue
            chrom = fields[1]
            motifPosLeft = int(fields[2])
            motifPosRight = int(fields[3])
            motifPos = int((motifPosRight + motifPosLeft) / 2.)
            # A chromosome without any region simply yields no candidates
            # (previously this raised KeyError).
            for region in regionDict.get(chrom, ()):
                distance = motifPos - region[4]
                if abs(distance) < radius:
                    regionListSorted[region].append(distance)
                    # Each motif site is credited to one region only.
                    break


def _write_report(outfile, regionListSorted, radius):
    """Write rank/distance lines and the no-motif summary to outfile.

    Regions are ordered by their tuple in descending order, i.e. highest
    score first.  One 'rank<TAB>distance' line is written per recorded
    distance, followed by a two-line summary giving how many regions had no
    motif out of the total number of regions.
    """
    keys = sorted(regionListSorted, reverse=True)
    present = 0
    # NOTE(review): ranks start at 2 for the top region, exactly as in the
    # original counter (initialised to 1, incremented before use).  Kept
    # unchanged because downstream consumers may rely on it -- confirm.
    for rank, region in enumerate(keys, 2):
        distances = regionListSorted[region]
        if not distances:
            continue
        for distance in distances:
            if abs(distance) <= radius:
                outfile.write(str(rank) + '\t' + str(distance) + '\n')
        present += 1

    noMotif = len(keys) - present
    outfile.write('#No motif found in:\n')
    outfile.write(str(noMotif) + '/' + str(len(keys)) + '\n')


def run():
    """Command-line entry point.

    Reads five positional arguments from sys.argv: the motif sites file, the
    motif name prefix, the radius (integer), the region calls file and the
    output file name.  Prints a usage message and exits with status 1 when
    fewer than five arguments are supplied.
    """
    # Five arguments are required, so argv must hold at least six entries.
    if len(sys.argv) < 6:
        print('usage: python %s Pouya_file motif_prefix radius regioncalls outputfilename' % sys.argv[0])
        sys.exit(1)

    getallsites = sys.argv[1]
    motif_prefix = sys.argv[2]
    radius = int(sys.argv[3])
    regioncalls = sys.argv[4]
    outfilename = sys.argv[5]

    regionListSorted, regionDict = _load_regions(regioncalls)

    # Open the output before the (potentially long) sites scan so that an
    # unwritable path fails early; the with-block guarantees it is closed.
    with open(outfilename, 'w') as outfile:
        _collect_motif_distances(getallsites, motif_prefix, radius,
                                 regionDict, regionListSorted)
        _write_report(outfile, regionListSorted, radius)
            
# Only execute when invoked as a script, so importing this module does not
# trigger argument parsing or sys.exit().
if __name__ == '__main__':
    run()
