##################################
#                                #
# Last modified 2020/11/18       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import math
from sets import Set

def _mean(values):
    """Return the arithmetic mean of ``values`` as a float, or NaN if empty.

    NaN (written to the report as 'nan') replaces the ZeroDivisionError an
    empty list would otherwise raise, matching the 'nan' placeholders the
    report already uses for values that cannot be computed.
    """
    if not values:
        return float('nan')
    return float(sum(values)) / len(values)


def _ratio(numerator, denominator):
    """Return ``numerator / denominator`` as a float, or NaN if the denominator is 0.

    NaN (rather than inf or an exception) is used so an undefined ratio
    shows up as 'nan' in the report like every other missing value.
    """
    if denominator == 0:
        return float('nan')
    return float(numerator) / denominator


def _read_matrix(inputfilename):
    """Parse the tab-separated square matrix file at ``inputfilename``.

    A line starting with '#' is the header: its fields after the first are
    integer bin positions labelling the columns. Every other non-blank line
    is a row whose first field is the row's bin position, followed by one
    float score per header column.

    Returns ``(Matrix, PosDict)`` where ``Matrix[row_pos][col_pos]`` is the
    score and ``PosDict[column_index]`` is the bin position of that column
    (column indices start at 1). A row whose position was not listed in the
    header raises KeyError.
    """
    PosDict = {}
    Matrix = {}
    with open(inputfilename) as infile:
        for line in infile:
            if not line.strip():
                # tolerate blank (e.g. trailing) lines instead of failing on int('')
                continue
            fields = line.split('\t')
            if line.startswith('#'):
                for column, field in enumerate(fields[1:], 1):
                    pos = int(field)
                    Matrix[pos] = {}
                    PosDict[column] = pos
                continue
            row = Matrix[int(fields[0])]
            for column, field in enumerate(fields[1:], 1):
                row[PosDict[column]] = float(field)
    return Matrix, PosDict


def _collect_scores(Matrix):
    """Group the upper-triangle scores of ``Matrix`` by bin separation.

    Returns ``(diagonal, intra, inter)``:
      diagonal -- list of Matrix[p][p] for every negative position p
      intra    -- {distance: [scores]} for pairs pos1 < pos2 < 0
      inter    -- {distance: [scores]} for pairs pos1 < 0 < pos2
    where distance = pos2 - pos1. Pairs involving position 0, or with both
    positions positive, are ignored.
    """
    diagonal = []
    intra = {}
    inter = {}
    positions = sorted(Matrix)
    for index, pos1 in enumerate(positions):
        # positions[index:] visits only pos2 >= pos1 (upper triangle)
        for pos2 in positions[index:]:
            if pos1 == pos2:
                if pos1 < 0:
                    diagonal.append(Matrix[pos1][pos2])
                continue
            distance = pos2 - pos1
            if pos1 < 0 and pos2 < 0:
                intra.setdefault(distance, []).append(Matrix[pos1][pos2])
            elif pos1 < 0 and pos2 > 0:
                inter.setdefault(distance, []).append(Matrix[pos1][pos2])
    return diagonal, intra, inter


def run():
    """Command-line entry point: summarise a matrix file by bin distance.

    Usage: python <script> inputfilename outfilename

    Reads the matrix from ``inputfilename`` (format described in
    ``_read_matrix``) and writes a tab-separated report to
    ``outputfilename`` with one row per bin distance:

      distance; count and mean of scores for pairs with both positions < 0
      (labelled IntraTAD); count and mean for pairs with pos1 < 0 < pos2
      (labelled InterTAD); the ratio of the two means; the ratio of their
      maxima; and the intra mean divided by (a) the mean negative-diagonal
      score and (b) the mean intra score at the separation of the first two
      header columns.

    Two summary rows follow: 'Total_Ratio' and 'Average_Ratio'. Anything
    that cannot be computed is written as 'nan'. Prints usage and exits
    with status 1 if fewer than two arguments are given.
    """
    if len(sys.argv) < 3:
        # both an input and an output file name are required
        print('usage: python %s inputfilename outfilename' % sys.argv[0])
        print('Note: only use on the output of metaplot.py!!!!!')
        sys.exit(1)

    inputfilename = sys.argv[1]
    outputfilename = sys.argv[2]

    Matrix, PosDict = _read_matrix(inputfilename)
    Diagonal, DistanceScoresIntra, DistanceScoresInter = _collect_scores(Matrix)

    DiagonalScore = _mean(Diagonal)
    # Mean intra score at the separation between the first two header
    # columns; NaN when the header has fewer than two columns or no intra
    # pair has that separation.
    if 1 in PosDict and 2 in PosDict:
        OffDiagonalScore = _mean(DistanceScoresIntra.get(PosDict[2] - PosDict[1], []))
    else:
        OffDiagonalScore = float('nan')

    distances = sorted(set(DistanceScoresIntra) | set(DistanceScoresInter))

    TotalIntraList = []
    TotalInterList = []
    ratios = []

    with open(outputfilename, 'w') as outfile:
        outfile.write('#distance\tn\tAverage_IntraTAD\tn\tAverage_InterTAD\tratio\tMaxMin_ratio\tCompactionScore_diagonal\tCompactionScore_off_diagonal\n')

        for D in distances:
            intra_scores = DistanceScoresIntra.get(D)
            inter_scores = DistanceScoresInter.get(D)

            intra_cols = ['nan', 'nan']
            inter_cols = ['nan', 'nan']
            ratio_cols = ['nan', 'nan']
            compaction_cols = ['nan', 'nan']

            if intra_scores is not None:
                S1 = _mean(intra_scores)
                TotalIntraList.append(S1)
                intra_cols = [str(len(intra_scores)), str(S1)]
                compaction_cols = [str(_ratio(S1, DiagonalScore)),
                                   str(_ratio(S1, OffDiagonalScore))]
            if inter_scores is not None:
                S2 = _mean(inter_scores)
                inter_cols = [str(len(inter_scores)), str(S2)]
            if intra_scores is not None and inter_scores is not None:
                # only distances present in both groups feed the inter total
                # and the per-distance ratio list
                TotalInterList.append(S2)
                ratio = _ratio(S1, S2)
                ratios.append(ratio)
                ratio_cols = [str(ratio),
                              str(_ratio(max(intra_scores), max(inter_scores)))]

            outfile.write('\t'.join([str(D)] + intra_cols + inter_cols + ratio_cols + compaction_cols) + '\n')

        # NOTE(review): TotalIntraList includes distances that have no inter
        # counterpart, while TotalInterList does not, so Total_Ratio compares
        # means over different distance sets. Kept as originally written;
        # confirm whether this asymmetry is intended.
        S1 = _mean(TotalIntraList)
        S2 = _mean(TotalInterList)
        outfile.write('Total_Ratio\tnan\t' + str(S1) + '\tnan\t' + str(S2) + '\t' + str(_ratio(S1, S2)) + '\tnan\tnan\tnan\n')
        outfile.write('Average_Ratio\tnan\t' + str(S1) + '\tnan\t' + str(S2) + '\t' + str(_mean(ratios)) + '\tnan\tnan\tnan\n')

# Run only when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    run()

