##################################
#                                #
# Last modified 2017/08/02       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from random import shuffle

def nearestNeighbor(cluster,clusters):
    """Find the merged cluster whose peak lies closest to ``cluster``'s peak.

    cluster  -- 8-tuple (chr, left, right, strand, peak, counts, RPM, SI);
                only the peak (index 4) is used for the distance.
    clusters -- dict mapping an id to a record whose 'coordinates' entry is
                the 5-tuple (chr, left, right, strand, peak).

    Returns the 6-tuple (id, chr, left, right, strand, peak) of the nearest
    record, or the empty string '' when no record is within 1e20 (in
    practice: when ``clusters`` is empty).  Because the comparison is
    ``<=``, ties go to the record visited last in the dict's iteration
    order.  Chromosome and strand are not compared.
    """
    _, _, _, _, queryPeak, _, _, _ = cluster
    best = ''
    bestDistance = 1e20
    for key in clusters:
        ccChr, ccLeft, ccRight, ccStrand, ccPeak = clusters[key]['coordinates']
        distance = math.fabs(ccPeak - queryPeak)
        if distance <= bestDistance:
            bestDistance, best = distance, (key, ccChr, ccLeft, ccRight, ccStrand, ccPeak)
    return best

# def minDistLessThanThreshold(cluster,clusters,maxClusterSize,maxPeakDistance):
# 
#     (chr,left,right,strand,peak,counts,RPM,SI) = cluster
#     minimumDist = False
#     CloseClusters = 0
#     for CC in clusters.keys():
#         (CCchr,CCleft,CCright,CCstrand,CCpeak) = clusters[CC]['coordinates']
#         if CCleft == left and CCright == right and CCpeak == peak:
#              continue
#         if max(CCright,right) - min(CCleft,left) <= maxClusterSize or math.fabs(CCpeak-peak) <= maxPeakDistance:
#              CloseClusters += 1
#              print '...', cluster, clusters[CC]['coordinates']
#              if CloseClusters >= 2:
#                  minimumDist = True
#                  break
#     
#     return minimumDist

def _read_clusters(config):
    """Load every cluster file listed in the config file.

    Each non-blank config line is ``<label> <tab> <filename>``.  Each
    non-blank line of a cluster file that does not start with '#' is
    tab-separated with columns chr, left, right, strand, peak, counts, RPM,
    SI, gene.

    Returns (ClusterDict, Labels):
      ClusterDict -- (gene name, chr) -> {label: [cluster 8-tuples]}, where a
                     cluster is (chr, left, right, strand, peak, counts, RPM, SI)
      Labels      -- dict whose keys are every label listed in the config
                     (including labels whose file contributed no clusters)
    """
    ClusterDict = {}
    Labels = {}
    with open(config) as configfile:
        for fileline in configfile:
            if not fileline.strip():
                continue  # tolerate blank lines instead of raising IndexError
            filefields = fileline.strip().split('\t')
            label = filefields[0]
            Labels[label] = 1
            with open(filefields[1]) as clusterfile:
                for line in clusterfile:
                    if line.startswith('#') or not line.strip():
                        continue
                    fields = line.split('\t')
                    chrom = fields[0]
                    cluster = (chrom, int(fields[1]), int(fields[2]), fields[3],
                               int(fields[4]), float(fields[5]),
                               float(fields[6]), float(fields[7]))
                    # Clusters are grouped per (gene name, chromosome); the gene
                    # name is '' for clusters without an associated gene.
                    gene = (fields[8].strip(), chrom)
                    ClusterDict.setdefault(gene, {}).setdefault(label, []).append(cluster)
    return ClusterDict, Labels

def _merge_gene_clusters(gene, labelClusters, maxClusterSize, maxPeakDistance):
    """Greedily merge one gene's clusters from all labels into shared clusters.

    Labels are visited in dict order, and each label's clusters in sorted
    order.  Each cluster joins its nearest existing merged cluster (by peak)
    if the peaks are closer than maxPeakDistance, or if the combined span
    would be at most maxClusterSize.  Otherwise it starts a new merged
    cluster.

    Returns a dict mapping integer ids (starting at 1) to a record with
    'coordinates' (chr, left, right, strand, peak), 'maxCounts' (highest RPM
    merged so far; the record's peak is that contributor's peak) and one
    (peak, counts, RPM, SI) entry per contributing label.
    """
    clusters = {}
    for label in labelClusters:
        if gene[0] == '':
            print('processing clusters without an associated gene: %s %s' % (gene[1], label))
        for cluster in sorted(labelClusters[label]):
            (chrom, left, right, strand, peak, counts, RPM, SI) = cluster
            # NOTE(review): nearestNeighbor compares peaks only and genes are keyed
            # by (name, chr), so clusters on opposite strands sharing a key (e.g.
            # the unannotated '' gene) can be merged; the record then takes the
            # strand of the latest contributor. Confirm this is intended.
            NN = nearestNeighbor(cluster, clusters)
            if NN != '':
                (CC, NNchr, NNleft, NNright, NNstrand, NNpeak) = NN
                # NOTE(review): the peak test is strict (<) while the span test is
                # inclusive (<=); kept as-is -- confirm the asymmetry is intended.
                if (math.fabs(NNpeak - peak) < maxPeakDistance or
                        max(NNright, right) - min(NNleft, left) <= maxClusterSize):
                    record = clusters[CC]
                    # The merged peak follows whichever contributor has the highest RPM.
                    if RPM >= record['maxCounts']:
                        newpeak = peak
                        record['maxCounts'] = RPM
                    else:
                        newpeak = NNpeak
                    record['coordinates'] = (chrom, min(NNleft, left), max(NNright, right), strand, newpeak)
                    # NOTE(review): if a cluster from this same label was already merged
                    # into this record, its values are replaced here, not combined.
                    record[label] = (peak, counts, RPM, SI)
                    continue
            # No merged cluster yet, or the nearest one is too far: start a new one.
            newCC = max(clusters) + 1 if clusters else 1
            clusters[newCC] = {'coordinates': (chrom, left, right, strand, peak),
                               'maxCounts': RPM,
                               label: (peak, counts, RPM, SI)}
    return clusters

def _format_row(gene, record, labels):
    """Format one merged cluster as a tab-separated output line (no newline).

    Labels absent from the record are written as '.', 0, 0, 0.
    """
    (chrom, left, right, strand, peak) = record['coordinates']
    parts = [chrom, str(left), str(right), strand, str(peak), gene[0]]
    for label in labels:
        if label in record:
            parts.extend(str(value) for value in record[label])
        else:
            parts.extend(['.', '0', '0', '0'])
    return '\t'.join(parts)

def run():
    """Merge clusters from several labelled samples into one shared table.

    Command line: config maxClusterSize maxPeakDistance outputfilename

    The config file lists ``<label> <tab> <cluster file>`` per line.  The
    output has a header line followed by one row per merged cluster: chr,
    left, right, strand, peak, gene, then peak/counts/RPM/SI for every label
    in sorted order ('.', 0, 0, 0 where a label has no cluster there).
    """
    # Four user arguments are read below (argv[1]..argv[4]), so five entries are needed.
    if len(sys.argv) < 5:
        print('usage: python %s config maxClusterSize maxPeakDistance outputfilename' % sys.argv[0])
        print('\tconfig file format: <label> <tab> filename')
        sys.exit(1)

    config = sys.argv[1]
    maxClusterSize = int(sys.argv[2])
    maxPeakDistance = int(sys.argv[3])
    outfilename = sys.argv[4]

    ClusterDict, Labels = _read_clusters(config)
    print('finished inputting clusters')

    Ls = sorted(Labels)
    header = ['#chr', 'left', 'right', 'strand', 'peak', 'genes']
    for label in Ls:
        header.extend(label + suffix for suffix in ('::peak', '::counts', '::RPM', '::SI'))

    with open(outfilename, 'w') as outfile:
        outfile.write('\t'.join(header) + '\n')
        for G, gene in enumerate(ClusterDict, 1):
            if G % 100 == 0:
                print('%d genes processed' % G)
            clusters = _merge_gene_clusters(gene, ClusterDict[gene], maxClusterSize, maxPeakDistance)
            for C in clusters:
                outfile.write(_format_row(gene, clusters[C], Ls) + '\n')
            
# Run only when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    run()
