##################################
#                                #
# Last modified 2018/01/18       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import os

def run():
    """Merge peak calls from several files into one per-label score table.

    Command line (read from ``sys.argv``)::

        list_of_files merge_radius outfile [-binary] [-printBed radius]

    ``list_of_files`` is a tab-separated file with the columns
    label, filename, chrFieldID, peakFieldID, scoreFieldID.  Peaks whose
    positions lie within ``merge_radius`` of the current summit are merged;
    the summit moves to whichever peak carries the higher score.  One line
    per merged peak is written to ``outfile`` with one column per label
    (the score, ``1`` with ``-binary``, or ``0`` when the label is absent).

    Exits with status 1 after printing the usage text when fewer than the
    three mandatory arguments are given.
    """
    # Three positional arguments are mandatory, so argv needs >= 4 entries.
    if len(sys.argv) < 4:
        print('usage: python %s list_of_files merge_radius outfile [-binary] [-printBed radius]' % sys.argv[0])
        print('       list_of_files format: label <tab> filename <tab> chrFieldID <tab> peakFieldID <tab> scoreFieldID')
        print('       Note 1: multiple files can be associated with the same label')
        print('       Note 2: when peaks are merged the summit coordinate is assigned to the region with the higher score, whether RPM or something else')
        print('       Note 3: if the peak is specified explicitly, indicate the field with the fieldID. ')
        print('               If you want the middle of the region, indicate that with "middle"')
        sys.exit(1)

    listoffiles = sys.argv[1]
    mergeradius = int(sys.argv[2])

    doBinary = '-binary' in sys.argv

    doPB = '-printBed' in sys.argv
    PBrad = 0
    if doPB:
        PBrad = int(sys.argv[sys.argv.index('-printBed') + 1])

    # The context manager guarantees the output is flushed and closed even
    # when parsing one of the inputs fails part-way through.
    with open(sys.argv[3], 'w') as outfile:
        PeakDict, LabelSet = _load_peaks(listoffiles)

        print('finished inputting peaks')

        chrKeys = sorted(PeakDict)
        labels = sorted(LabelSet)

        if doPB:
            header = ['#chr', 'left', 'right']
        else:
            header = ['#chr::peak_position']
        outfile.write('\t'.join(header + labels) + '\n')

        NewPeakDict = {}
        for chrom in chrKeys:
            print(chrom)
            NewPeakDict[chrom] = _merge_peaks(PeakDict[chrom], mergeradius)

        print('finished merging peaks')

        for chrom in chrKeys:
            merged = NewPeakDict[chrom]
            for peak in sorted(merged):
                if doPB:
                    # Clamp the left edge so the interval never goes negative.
                    columns = [chrom, str(max(0, peak - PBrad)), str(peak + PBrad)]
                else:
                    columns = [chrom + '::' + str(peak)]
                scores = merged[peak]
                for label in labels:
                    if label not in scores:
                        columns.append('0')
                    elif doBinary:
                        columns.append('1')
                    else:
                        columns.append(str(scores[label]))
                outfile.write('\t'.join(columns) + '\n')


def _iter_lines(path):
    """Yield the lines of ``path``, decompressing ``.bz2``/``.gz`` on the fly.

    Compressed files are streamed through the external ``bzip2``/``gunzip``
    programs.  The command is passed as an argument list (no shell), so file
    names containing spaces or shell metacharacters are handled safely.
    """
    import subprocess

    if path.endswith('.bz2'):
        command = ['bzip2', '-cd', path]
    elif path.endswith('.gz'):
        command = ['gunzip', '-c', path]
    else:
        command = None

    if command is None:
        with open(path) as handle:
            for line in handle:
                yield line
        return

    proc = subprocess.Popen(command, stdout=subprocess.PIPE,
                            universal_newlines=True)
    try:
        for line in proc.stdout:
            yield line
    finally:
        # Always release the pipe and reap the child process.
        proc.stdout.close()
        if proc.wait() != 0:
            sys.stderr.write('warning: "%s" exited with status %d\n'
                             % (' '.join(command), proc.returncode))


def _load_peaks(listoffiles):
    """Read every file named in ``listoffiles``.

    Returns ``(peaks, labels)`` where ``peaks`` maps
    chromosome -> position -> label -> highest score seen, and ``labels`` is
    the set of all labels named in the list file (including labels whose
    files contributed no peaks).
    """
    peaks = {}
    labels = set()
    with open(listoffiles) as listing:
        for entry in listing:
            if not entry.strip():
                continue  # tolerate blank lines in the list file
            spec = entry.rstrip('\r\n').split('\t')
            label = spec[0]
            path = spec[1]
            labels.add(label)
            print('%s %s' % (label, path))
            chrID = int(spec[2])
            peakID = spec[3]
            scoreID = int(spec[4])
            # An explicit peak column is a number; the two keywords are not.
            if peakID in ('narrowPeak', 'middle'):
                peakField = None
            else:
                peakField = int(peakID)

            for line in _iter_lines(path):
                if line.startswith('#') or not line.strip():
                    continue
                fields = line.rstrip('\r\n').split('\t')
                chrom = fields[chrID]
                if peakID == 'narrowPeak':
                    # narrowPeak: start coordinate plus the summit offset.
                    peak = int(fields[1]) + int(fields[9])
                elif peakID == 'middle':
                    # Midpoint of the two columns following the chromosome.
                    peak = int((float(fields[chrID + 1]) + float(fields[chrID + 2])) / 2)
                else:
                    peak = int(fields[peakField])
                score = float(fields[scoreID])

                labelScores = peaks.setdefault(chrom, {}).setdefault(peak, {})
                # Keep only the best score per (position, label).
                if label not in labelScores or labelScores[label] < score:
                    labelScores[label] = score
    return peaks, labels


def _merge_peaks(peakScores, mergeradius):
    """Merge the peaks of one chromosome.

    ``peakScores`` maps position -> {label: score}.  Walking the positions
    in ascending order, a peak within ``mergeradius`` of the current summit
    is folded into it; the summit moves to the newcomer when the newcomer's
    best score is strictly higher.  Per-label scores of a merged cluster are
    the maxima over its members.  Returns a new position -> {label: score}
    dict; the input is left unmodified.
    """
    merged = {}
    positions = sorted(peakScores)
    if not positions:
        return merged

    summit = positions[0]
    summitScores = dict(peakScores[summit])
    for position in positions[1:]:
        candidate = peakScores[position]
        if position - summit <= mergeradius:
            # Compare true maxima so negative scores are ranked correctly
            # (a zero-initialised accumulator would treat them all as ties).
            if max(candidate.values()) > max(summitScores.values()):
                summit = position
            for label, score in candidate.items():
                if label not in summitScores or score > summitScores[label]:
                    summitScores[label] = score
        else:
            merged[summit] = summitScores
            summit = position
            summitScores = dict(candidate)
    merged[summit] = summitScores
    return merged

# Script entry point: this call is unconditional, so it runs whenever the
# file is executed and also whenever it is imported as a module.
run()
