##################################
#                                #
# Last modified 2016/10/25       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys

def _parse_interval(line, first_col):
    """Return (chrom, start, stop) parsed from a tab-separated line.

    The chromosome name is read from column ``first_col`` and the integer
    start/stop coordinates from the two columns that follow it.
    """
    fields = line.strip().split('\t')
    return (fields[first_col],
            int(fields[first_col + 1]),
            int(fields[first_col + 2]))


def run():
    """Append a mappable-fraction column to every region of a BED file.

    Command-line arguments (read from ``sys.argv``):
        bed              tab-separated file of regions; lines starting with
                         '#' are treated as header/comment lines
        chrFieldID       0-based index of the chromosome column; start and
                         stop are taken from the next two columns
        mappability-wig  tab-separated file with four columns: chromosome,
                         start, stop, value
        score            minimum value for a position to count as mappable
        outfile          path of the file to write

    Intervals are handled as half-open ``[start, stop)`` ranges.  Every data
    line of the BED file is copied to the output with one extra column: the
    fraction of its positions covered by a wig interval whose value is
    ``>= score``.  '#' lines are copied with a ``Mappable_fraction`` column
    label appended.  A region with no positions (``stop <= start``) gets 0.0.

    Exits with status 1 after printing a usage message when fewer than five
    arguments are supplied.
    """
    # Five arguments are required, so argv (which also holds the script
    # name) must have at least six entries.
    if len(sys.argv) < 6:
        print('usage: python %s bed chrFieldID mappability-wig score outfile' % sys.argv[0])
        sys.exit(1)

    bed = sys.argv[1]
    chrFieldID = int(sys.argv[2])
    wig = sys.argv[3]
    score = float(sys.argv[4])
    outputfilename = sys.argv[5]

    # chrom -> {position: 0/1}; 1 means the position reached the threshold.
    coverage = {}
    with open(bed) as bedfile:
        for line in bedfile:
            if line.startswith('#'):
                continue
            chrom, start, stop = _parse_interval(line, chrFieldID)
            positions = coverage.setdefault(chrom, {})
            for pos in range(start, stop):
                positions[pos] = 0

    print('finished importing regions')

    with open(wig) as wigfile:
        for lineno, line in enumerate(wigfile, 1):
            if lineno % 10000000 == 0:
                print(str(lineno // 1000000) + 'M lines processed')
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            chrom = fields[0]
            if chrom not in coverage:
                continue
            positions = coverage[chrom]
            start = int(fields[1])
            stop = int(fields[2])
            overlapping = [pos for pos in range(start, stop) if pos in positions]
            # The value column is only parsed for lines that actually overlap
            # a region, and then only once per line.
            if overlapping and float(fields[3]) >= score:
                for pos in overlapping:
                    positions[pos] = 1

    with open(outputfilename, 'w') as outfile:
        with open(bed) as bedfile:
            for line in bedfile:
                if line.startswith('#'):
                    outfile.write(line.strip() + '\t' + 'Mappable_fraction' + '\n')
                    continue
                chrom, start, stop = _parse_interval(line, chrFieldID)
                length = stop - start
                if length > 0:
                    positions = coverage[chrom]
                    mappable = sum(positions[pos] for pos in range(start, stop))
                    fraction = mappable / float(length)
                else:
                    # An empty region has no positions; report 0.0 rather
                    # than dividing by zero.
                    fraction = 0.0
                outfile.write(line.strip() + '\t' + str(fraction) + '\n')

# Run only when executed as a script, so that importing this module does not
# immediately read sys.argv and start processing files.
if __name__ == '__main__':
    run()

