##################################
#                                #
# Last modified 2017/10/04       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import gc
import pysam
import string
import math
from sets import Set
import os

def _iter_input_lines(path):
    """Yield the raw text lines of *path*, decompressing on the fly.

    ``.bz2``, ``.gz`` and ``.zip`` inputs are streamed through ``bzip2 -cd``,
    ``gunzip -c`` and ``unzip -p`` respectively; any other file is read
    directly. The external tool is started with an argument list (no shell),
    so file names containing spaces or shell metacharacters are passed
    through verbatim.

    Raises:
        IOError/OSError: if the file cannot be opened, the decompressor cannot
            be started, or the decompressor exits with a non-zero status
            (e.g. missing or corrupt archive).
    """
    import subprocess

    if path.endswith('.bz2'):
        cmd = ['bzip2', '-cd', path]
    elif path.endswith('.gz'):
        cmd = ['gunzip', '-c', path]
    elif path.endswith('.zip'):
        cmd = ['unzip', '-p', path]
    else:
        with open(path) as handle:
            for line in handle:
                yield line
        return

    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, universal_newlines=True)
    try:
        for line in proc.stdout:
            yield line
    finally:
        # Close the pipe and reap the child however iteration ends
        # (exhaustion or generator close), so no zombie process is left.
        proc.stdout.close()
        status = proc.wait()
    # Without this check a failed decompression looks like an empty input.
    if status != 0:
        raise IOError('%s exited with status %d while reading %s'
                      % (cmd[0], status, path))


def _bin_fragment(fragment_id, window):
    """Snap a ``chrom:left-right|strand`` fragment ID onto a *window*-bp grid.

    Both coordinates are floored to a multiple of *window*. Returns the merge
    key ``(chrom, left, right, strand)``, or ``None`` when both ends fall into
    the same window (the binned fragment would have zero length).
    """
    chrom = fragment_id.split(':')[0]
    coords = fragment_id.split(':')[1].split('|')[0].split('-')
    left = int(coords[0])
    right = int(coords[1])
    left -= left % window
    right -= right % window
    if abs(right - left) < window:
        return None
    # Strand is only needed (and only parsed) for fragments that are kept.
    strand = fragment_id.split('|')[1]
    return (chrom, left, right, strand)


def _usage_exit(prog):
    """Print the command-line usage and exit with status 1."""
    print('usage: python %s fragment_counts window_size outputfilename [-minCounts n]' % prog)
    print('\tNote: Assumed format: singleFieldCoords <tab> <count_fields>, with a header line')
    sys.exit(1)


def run():
    """Merge fragment counts into window-aligned bins and write them out.

    Command line: ``fragment_counts window_size outputfilename [-minCounts n]``.

    Each non-header input line is ``chrom:left-right|strand<TAB>count...``.
    Both ends of every fragment are floored to a multiple of ``window_size``.
    Fragments whose ends land in the same window are dropped. The count
    columns of fragments that map to the same ``(chrom, left, right, strand)``
    key are summed element-wise.

    Lines starting with ``#`` are copied to the output immediately, with
    their first field replaced by ``#``. Blank lines are skipped. The merged
    fragments are then written sorted by ``(chrom, left, right, strand)``.

    With ``-minCounts n``, merged fragments whose largest count is below
    ``n`` are not written.
    """
    # Three positional arguments are required (argv[1..3]).
    if len(sys.argv) < 4:
        _usage_exit(sys.argv[0])

    counts_path = sys.argv[1]
    window = int(sys.argv[2])
    output_path = sys.argv[3]
    if window <= 0:
        print('window_size must be a positive integer')
        sys.exit(1)

    min_counts = None
    if '-minCounts' in sys.argv:
        value_pos = sys.argv.index('-minCounts') + 1
        if value_pos >= len(sys.argv):
            _usage_exit(sys.argv[0])
        min_counts = int(sys.argv[value_pos])
        print('will discard merged fragments with less than %d counts in all samples' % min_counts)

    merged = {}
    with open(output_path, 'w') as outfile:
        for n_read, raw in enumerate(_iter_input_lines(counts_path), 1):
            if n_read % 1000000 == 0:
                print('%d fragments processed' % n_read)
            line = raw.strip()
            # Skip blank lines rather than treating them as end of input.
            if not line:
                continue
            fields = line.split('\t')
            if line.startswith('#'):
                outfile.write('\t'.join(['#'] + fields[1:]) + '\n')
                continue
            key = _bin_fragment(fields[0], window)
            if key is None:
                continue
            counts = [int(value) for value in fields[1:]]
            totals = merged.get(key)
            if totals is None:
                merged[key] = counts
            else:
                for i, value in enumerate(counts):
                    totals[i] += value
        print('finished inputting fragment counts')

        keys = sorted(merged)
        n_keys = len(keys)
        for n_done, key in enumerate(keys, 1):
            if n_done % 1000000 == 0:
                print('%d merged fragments processed (out of %d)' % (n_done, n_keys))
            totals = merged[key]
            if min_counts is not None and max(totals) < min_counts:
                continue
            chrom, left, right, strand = key
            fragment_id = '%s:%d-%d|%s' % (chrom, left, right, strand)
            outfile.write('\t'.join([fragment_id] + [str(c) for c in totals]) + '\n')

    print('finished printing merged fragment counts')

# Run only when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    run()
