##################################
#                                #
# Last modified 2019/05/17       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from sets import Set
import h5py
import numpy as np    
import os

# https://github.com/kundajelab/basepair/blob/master/basepair/cli/imp_score.py#L311

def _collect_scores(chromosomes, startpositions, impscores0, impscores1):
    """Sum the importance scores of every position, grouped by chromosome.

    The four arguments are indexed in parallel: entry ``i`` of
    ``chromosomes`` / ``startpositions`` describes block ``i`` of
    ``impscores0`` / ``impscores1``.

    Returns a dict ``{chromosome: {position: (sum0, sum1)}}`` where ``sum0``
    and ``sum1`` are the sums of the per-position score rows taken from
    ``impscores0`` and ``impscores1`` respectively.
    """
    scores_by_chrom = {}
    for i, block0 in enumerate(impscores0):
        chrom = chromosomes[i]
        start = startpositions[i]
        block1 = impscores1[i]
        positions = scores_by_chrom.setdefault(chrom, {})
        for j, scores in enumerate(block0):
            # The first row of a block is stored at start + 1, not start.
            # NOTE(review): this looks like a possible off-by-one if the
            # range starts are 0-based -- confirm before re-enabling run().
            positions[start + j + 1] = (sum(scores), sum(block1[j]))
        done = i + 1
        if done % 100 == 0:
            print('%d blocks processed' % done)
    return scores_by_chrom


def _write_tracks(scores_by_chrom, outfile1, outfile2):
    """Write the summed scores as ``chrom<TAB>pos<TAB>pos+1<TAB>score`` lines.

    Chromosomes and positions are written in sorted order.  The first sum of
    each position goes to ``outfile1`` and the second to ``outfile2``; sums
    equal to zero are skipped.
    """
    for chrom in sorted(scores_by_chrom):
        print(chrom)
        positions = scores_by_chrom[chrom]
        for pos in sorted(positions):
            (s0, s1) = positions[pos]
            outline = chrom + '\t' + str(pos) + '\t' + str(pos + 1)
            if float(s0) != 0:
                outfile1.write(outline + '\t' + str(s0) + '\n')
            if float(s1) != 0:
                outfile2.write(outline + '\t' + str(s1) + '\n')


def run():
    """Convert task1 importance scores from an HDF5 file into two tracks.

    Command line: ``python <script> imp-score outprefix``

    Reads ``/metadata/range/chr``, ``/metadata/range/start``,
    ``/hyp_imp/task1/weighted/0`` and ``/hyp_imp/task1/weighted/1`` from the
    ``imp-score`` file and writes ``<outprefix>.imp-score0.wig`` and
    ``<outprefix>.imp-score1.wig``.

    The conversion is currently disabled: after the argument check the
    function always prints 'need to fix masking first' and exits with
    status 1, so the code after that guard is not reached until the guard
    is removed.
    """
    # Two positional arguments are required, i.e. at least three argv entries.
    if len(sys.argv) < 3:
        print('usage: python %s imp-score outprefix' % sys.argv[0])
        print('\tby default task1 will be printed out')
        sys.exit(1)

    # Deliberate guard left by the author; do not remove it before the
    # masking issue it refers to has been fixed.
    print('need to fix masking first')
    sys.exit(1)

    input_path = sys.argv[1]
    outprefix = sys.argv[2]

    # Context managers close all three files even if processing fails.
    with h5py.File(input_path, 'r') as fMfile:
        scores_by_chrom = _collect_scores(
            fMfile['/metadata/range/chr'],
            fMfile['/metadata/range/start'],
            fMfile['/hyp_imp/task1/weighted/0'],
            fMfile['/hyp_imp/task1/weighted/1'],
        )

    with open(outprefix + '.imp-score0.wig', 'w') as outfile1, \
            open(outprefix + '.imp-score1.wig', 'w') as outfile2:
        _write_tracks(scores_by_chrom, outfile1, outfile2)
    
# Run the conversion only when this file is executed as a script, so that
# importing the module does not trigger argument parsing or sys.exit().
if __name__ == '__main__':
    run()
