##################################
#                                #
# Last modified 2018/05/19       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from scipy.cluster.hierarchy import dendrogram, linkage, fcluster

# Constant the original script hard-coded in two places: scores are turned
# into distances as (_MAX_SCORE - score) and rescaled as score / _MAX_SCORE.
# NOTE(review): presumably the largest possible score in the input matrix;
# the origin of the value 324 is not visible in this file -- confirm.
_MAX_SCORE = 324.0

_USAGE = 'usage: python %s  input outfilename [-rescale min max] [-upperTriangular]'


def _read_score_matrix(path):
    """Parse a tab-separated, labelled score matrix from *path*.

    Lines starting with '#' and blank lines are skipped.  On every other
    line the first field is the row label and the remaining fields are
    parsed as floats.

    Returns a tuple (labels, rows, condensed):
      labels    -- row labels in file order
      rows      -- the full list of scores of each row
      condensed -- for row r, the scores in columns r+1.. (the strict
                   upper triangle when the matrix is square), converted to
                   distances as _MAX_SCORE - score and concatenated row
                   by row

    Raises ValueError if a score field is not a valid float.
    """
    labels = []
    rows = []
    condensed = []
    with open(path) as handle:
        for line in handle:
            if line.startswith('#'):
                continue
            stripped = line.strip()
            if not stripped:
                # A blank (e.g. trailing) line carries no row; skipping it
                # keeps labels/rows aligned with the condensed distances.
                continue
            fields = stripped.split('\t')
            values = [float(field) for field in fields[1:]]
            row_index = len(labels)
            labels.append(fields[0])
            rows.append(values)
            condensed.extend(_MAX_SCORE - value
                             for value in values[row_index + 1:])
    return labels, rows, condensed


def _clustered_row_order(condensed):
    """Return the row indices ordered by their flat-cluster id.

    The condensed distances are clustered with scipy's Ward linkage
    (optimal leaf ordering enabled) and cut with fcluster at distance 0.
    Rows are then listed by ascending cluster id, keeping file order inside
    a cluster.  The raw cluster assignment and the id -> rows mapping are
    printed, as the original script did.

    NOTE(review): the threshold of 0 presumably puts every row in its own
    flat cluster so that the ids encode a dendrogram-like ordering --
    verify against the scipy fcluster documentation.
    """
    Z = linkage(condensed, method='ward', metric='euclidean',
                optimal_ordering=True)
    clusters = fcluster(Z, 0, criterion='distance')

    members = {}
    for row_index, cluster_id in enumerate(clusters):
        members.setdefault(cluster_id, []).append(row_index)

    print(clusters)

    print(members)

    return [row_index
            for cluster_id in sorted(members)
            for row_index in members[cluster_id]]


def run():
    """Reorder a score matrix by hierarchical clustering and write it out.

    Command line (read from sys.argv):
      input            tab-separated matrix: label, then one score per column
      outfilename      destination file
      -rescale min max write min + (score / _MAX_SCORE) * (max - min)
                       instead of the raw score
      -upperTriangular leave the diagonal and everything below it empty

    The output starts with a '#TF', '#TF', <labels...> header line; each
    following line holds the row label twice and then the row's scores, with
    rows and columns both in clustered order.

    Exits with status 1 after printing the usage line when the two
    positional arguments, or the two values after -rescale, are missing.
    """
    # Both positional arguments are read below, so at least three argv
    # entries are required (the original check of < 2 let a missing
    # outfilename through and crashed with IndexError).
    if len(sys.argv) < 3:
        print(_USAGE % sys.argv[0])
        sys.exit(1)

    inputfilename = sys.argv[1]
    outfilename = sys.argv[2]

    doRescale = '-rescale' in sys.argv
    minR = maxR = None
    if doRescale:
        flag_pos = sys.argv.index('-rescale')
        if flag_pos + 2 >= len(sys.argv):
            print(_USAGE % sys.argv[0])
            sys.exit(1)
        minR = float(sys.argv[flag_pos + 1])
        maxR = float(sys.argv[flag_pos + 2])
        print('will rescale scores to %s %s' % (minR, maxR))

    doUT = '-upperTriangular' in sys.argv
    if doUT:
        print('will output an upper triangular matrix')

    labels, rows, condensed = _read_score_matrix(inputfilename)
    order = _clustered_row_order(condensed)

    with open(outfilename, 'w') as outfile:
        header = ['#TF', '#TF'] + [labels[k] for k in order]
        outfile.write('\t'.join(header) + '\n')

        for i, k in enumerate(order):
            scores = rows[k]
            cells = [labels[k], labels[k]]
            for j, k2 in enumerate(order):
                if doUT and j <= i:
                    # Diagonal and lower triangle are left as empty cells.
                    cells.append('')
                elif doRescale:
                    cells.append(str(minR + (scores[k2] / _MAX_SCORE) * (maxR - minR)))
                else:
                    cells.append(str(scores[k2]))
            outfile.write('\t'.join(cells) + '\n')


# Run only when executed as a script, so that importing this module does not
# trigger argument parsing and file I/O as a side effect.
if __name__ == '__main__':
    run()

