##################################
#                                #
# Last modified 01/14/2012       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string

def _read_chrom_sizes(path, keep=None):
    """Return {chrom: size} parsed from a tab-separated chrom.sizes file.

    Each non-blank line is ``chrom <tab> size``; blank lines are skipped.
    When ``keep`` is not None, only chromosomes whose name is in ``keep``
    are returned (names in ``keep`` that never appear are silently ignored).
    """
    sizes = {}
    with open(path) as handle:
        for line in handle:
            fields = line.strip().split('\t')
            if not fields[0]:
                continue  # blank line
            chrom = fields[0]
            size = int(fields[1])
            if keep is not None and chrom not in keep:
                continue
            sizes[chrom] = size
    return sizes


def _bin_coverage(path, chrom_sizes, radius):
    """Sum per-base coverage values into fixed-width bins.

    ``path`` is read as whitespace-separated ``chrom start end value`` lines
    (bedGraph-style); the interval covers positions start..end-1. Lines whose
    first field is not a key of ``chrom_sizes`` (track/header lines, other
    chromosomes) are skipped, and positions outside [0, size) are ignored.

    Returns {chrom: [sum_0, sum_1, ...]} where sum_b is the total, over every
    position p with b*radius <= p < min((b+1)*radius, size), of the value
    covering p. Each interval contributes value * overlap to every bin it
    touches, so memory is proportional to the number of bins, not bases.

    NOTE: overlapping intervals are added together. (A previous per-base
    implementation let the later line overwrite the earlier one; bedGraph
    intervals are not supposed to overlap, so this only matters for
    malformed input.)
    """
    bins = {}
    for chrom, size in chrom_sizes.items():
        # same count as len(range(0, size, radius))
        bins[chrom] = [0.0] * ((size + radius - 1) // radius)
    with open(path) as handle:
        for k, line in enumerate(handle, 1):
            if k % 10000000 == 0:
                print('%d M lines processed' % (k // 1000000))
            fields = line.split()
            if not fields or fields[0] not in bins:
                continue
            chrom = fields[0]
            start = max(int(fields[1]), 0)
            end = min(int(fields[2]), chrom_sizes[chrom])
            if start >= end:
                continue  # empty or entirely off-chromosome interval
            value = float(fields[3])
            totals = bins[chrom]
            b = start // radius
            # walk the interval one bin at a time, crediting the overlap
            while start < end:
                bin_end = min((b + 1) * radius, end)
                totals[b] += value * (bin_end - start)
                start = bin_end
                b += 1
    return bins


def run(argv=None):
    """Write a table of binned coverage averages, one row per dataset.

    Command line (``argv`` defaults to ``sys.argv``):
        prog list_of_wig_files chrom.sizes averaging_radius outputfilename
             [-chr chr1,chr2...chrN]

    ``list_of_wig_files`` has ``label <tab> filename`` lines (blank lines are
    skipped). Each chromosome of ``chrom.sizes`` (restricted to the ``-chr``
    list if given) is cut into bins starting at 0, radius, 2*radius, ...
    The output starts with a ``#Dataset`` header of ``chrom:binstart``
    columns (chromosomes sorted by name). It then has one row per dataset:
    the label followed by each bin's summed coverage divided by
    ``averaging_radius``. The truncated last bin of a chromosome is also
    divided by the full radius.

    Exits with status 1 on missing arguments, a missing ``-chr`` value, or a
    non-positive radius.
    """
    if argv is None:
        argv = sys.argv

    # five entries needed: program name plus four positional arguments
    if len(argv) < 5:
        print('usage: python %s list_of_wig_files chrom.sizes averaging_radius outputfilename [-chr chr1,chr2...chrN]' % argv[0])
        print('       list_of_wig_files format: label <tab> filename')
        print('       This script will take a number of wig files and will output a table with the datasets on the X and the sorted positions on the Y (averaged across the indicated radius)')
        sys.exit(1)

    keep = None
    if '-chr' in argv:
        value_index = argv.index('-chr') + 1
        if value_index >= len(argv):
            print('-chr requires a comma-separated list of chromosomes')
            sys.exit(1)
        keep = set(argv[value_index].split(','))
        print('will output the following chromosomes: %s' % sorted(keep))

    inputfilename = argv[1]
    radius = int(argv[3])
    if radius <= 0:
        print('averaging_radius must be a positive integer')
        sys.exit(1)
    outfilename = argv[4]

    chrom_sizes = _read_chrom_sizes(argv[2], keep)
    chromosomes = sorted(chrom_sizes)

    header = ['#Dataset']
    for chrom in chromosomes:
        header.extend(chrom + ':' + str(start)
                      for start in range(0, chrom_sizes[chrom], radius))

    with open(outfilename, 'w') as outfile:
        outfile.write('\t'.join(header) + '\n')
        with open(inputfilename) as listing:
            for entry in listing:
                entry = entry.rstrip('\r\n')  # tolerate CRLF list files
                if not entry.strip():
                    continue
                fields = entry.split('\t')
                label = fields[0]
                print(label)
                bins = _bin_coverage(fields[1], chrom_sizes, radius)
                row = [label]
                for chrom in chromosomes:
                    row.extend(str(total / radius) for total in bins[chrom])
                outfile.write('\t'.join(row) + '\n')
   
# Run only when executed as a script, so importing this module does not
# parse sys.argv or call sys.exit.
if __name__ == '__main__':
    run()
