##################################
#                                #
# Last modified 10/08/2012       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import pysam


def _parse_interval(line):
    """Split one tab-separated ``chrom left right score`` line.

    Returns ``(chrom, left, right, score_text)``, or ``None`` for blank lines
    and bedGraph ``track``/``browser``/``#`` header lines, which carry no
    interval. The score is returned as text so each caller converts it the
    way it needs.
    """
    if not line.strip() or line.startswith(('#', 'track', 'browser')):
        return None
    fields = line.strip().split('\t')
    return fields[0], int(fields[1]), int(fields[2]), fields[3]


def _load_signal(wiggle):
    """First pass: read the signal wiggle into ``{chrom: {position: score}}``.

    One dict entry is stored per covered base; when intervals overlap, the
    later line's score wins. Scores are truncated to ``int``.
    """
    signal = {}
    with open(wiggle) as handle:
        for k, line in enumerate(handle, 1):
            if k % 1000000 == 0:
                print('%d lines processed' % k)
            record = _parse_interval(line)
            if record is None:
                continue
            chrom, left, right, score = record
            if chrom not in signal:
                print(chrom)
                signal[chrom] = {}
            score = int(float(score))
            per_base = signal[chrom]
            for i in range(left, right):
                per_base[i] = score
    return signal


def _find_candidates(mappability, signal, size, flanking_region_size):
    """Return ``(chrom, start, end)`` signal-free runs of length >= ``size``.

    Only mappability lines with score >= 2 that are long enough to hold the
    region plus both flanks are searched. The search window excludes
    ``flanking_region_size`` bases at each end, so both flanks of every
    candidate lie inside the same mappable line. ``end`` is exclusive.
    """
    candidates = []
    with open(mappability) as handle:
        for line in handle:
            record = _parse_interval(line)
            if record is None:
                continue
            chrom, left, right, score = record
            if float(score) < 2:
                continue
            if right - left < 2 * flanking_region_size + size:
                continue
            # A chromosome absent from the signal file simply has no signal.
            chrom_signal = signal.get(chrom, {})
            stop = right - flanking_region_size
            start = None  # first position of the current signal-free run
            for i in range(left + flanking_region_size, stop):
                if chrom_signal.get(i, 0) > 0:
                    if start is not None and i - start >= size:
                        candidates.append((chrom, start, i))
                    start = None
                elif start is None:
                    start = i
            # A run still open at the end of the window is a candidate too.
            if start is not None and stop - start >= size:
                candidates.append((chrom, start, stop))
    return candidates


def _sum_flank_signal(wiggle, candidates, flanking_region_size):
    """Second pass: total signal per base over every candidate's two flanks.

    Returns ``{chrom: {position: float_total}}``, containing only flank
    positions. Positions with no signal stay 0.0. Overlapping wiggle lines
    are summed.
    """
    wanted = {}
    for chrom, left, right in candidates:
        per_base = wanted.setdefault(chrom, {})
        for i in range(left - flanking_region_size, left):
            per_base[i] = 0.0
        for i in range(right, right + flanking_region_size):
            per_base[i] = 0.0

    with open(wiggle) as handle:
        for k, line in enumerate(handle, 1):
            if k % 1000000 == 0:
                print('%d lines processed, second pass' % k)
            record = _parse_interval(line)
            if record is None:
                continue
            chrom, left, right, score = record
            per_base = wanted.get(chrom)
            if per_base is None:
                # No flank positions on this chromosome; skip the base-by-base scan.
                continue
            score = int(float(score))
            for i in range(left, right):
                if i in per_base:
                    per_base[i] += score
    return wanted


def _flank_mean(per_base, start, stop):
    """Return the mean of ``per_base`` over positions [start, stop), or 0.0 if empty."""
    width = stop - start
    if width <= 0:
        return 0.0
    # Sum first, then divide once: one rounding step instead of one per base.
    return sum(per_base[i] for i in range(start, stop)) / float(width)


def run():
    """Find long signal-free stretches inside perfectly mappable regions.

    Command-line arguments (read from ``sys.argv``):
        1. mappability wiggle: tab-separated ``chrom left right score`` lines;
           only lines with score >= 2 are searched.
        2. signal wiggle: tab-separated ``chrom left right score`` lines.
        3. minimum length of a signal-free stretch to report.
        4. flanking region size; this many bases at each end of a mappable
           line are kept out of the search so the flanks stay inside it.
        5. output file name.

    Writes a header line, then one line per candidate: chrom, start, end
    (exclusive), length, and the mean signal over the ``flanking_region_size``
    bases left of start and right of end. Bases without signal count as 0.
    Exits with status 1 and a usage message when arguments are missing.
    """
    if len(sys.argv) < 6:
        # Five positional arguments are required; sys.argv[5] is the output file.
        print('usage: python %s mappability_wiggle signal_wiggle min_radius flanking_regions_size outfilename' % sys.argv[0])
        print('Note1: it is assumed that the signal wiggle is in raw reads, not RPM')
        print('Note2: it is assumed that the mappability wiggle is a combined 5p on, i.e. a signal of 2 indicates perfect mappability in both directions')
        print('Note3: prefect mappability will be required for all bases in the region and the two flanking regions')
        sys.exit(1)

    mappability = sys.argv[1]
    wiggle = sys.argv[2]
    size = int(sys.argv[3])
    flanking_region_size = int(sys.argv[4])
    outfilename = sys.argv[5]

    # Opened up front so a bad output path fails before the slow passes.
    with open(outfilename, 'w') as outfile:
        header = ('#chr\tleft\tright\tlength\tleft_' + str(flanking_region_size)
                  + '_average_reads\tright_' + str(flanking_region_size) + '_average_reads')
        outfile.write(header + '\n')

        signal = _load_signal(wiggle)
        print('finished importing scores')

        candidates = _find_candidates(mappability, signal, size, flanking_region_size)
        print(candidates)
        print('finished compiling candidate regions list')

        # The per-base table can be very large; release it before the second pass.
        del signal
        flank_signal = _sum_flank_signal(wiggle, candidates, flanking_region_size)

        for chrom, left, right in candidates:
            per_base = flank_signal[chrom]
            left_average = _flank_mean(per_base, left - flanking_region_size, left)
            right_average = _flank_mean(per_base, right, right + flanking_region_size)
            outfile.write('\t'.join([chrom, str(left), str(right), str(right - left),
                                     str(left_average), str(right_average)]) + '\n')
            
# Run only when executed as a script, not when imported.
if __name__ == '__main__':
    run()
        