##################################
#                                #
# Last modified 2017/04/09       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import os
import string
import math
import numpy as np
from scipy.cluster.hierarchy import linkage, dendrogram

def _iter_input_lines(path):
    """Yield the text lines of *path*, decompressing .bz2 and .gz inputs.

    Compressed files are streamed through the external ``bzip2`` / ``zcat``
    programs.  The command is passed as an argument list (no shell), so file
    names containing spaces or shell metacharacters are handled safely.
    Plain files are opened directly.
    """
    # Local import: keeps this block independent of the file's import section.
    import subprocess

    if path.endswith('.bz2'):
        command = ['bzip2', '-cd', path]
    elif path.endswith('.gz'):
        command = ['zcat', path]
    else:
        # Uncompressed input needs no helper process.
        with open(path, 'r') as handle:
            for line in handle:
                yield line
        return

    proc = subprocess.Popen(command, stdout=subprocess.PIPE,
                            universal_newlines=True)
    try:
        for line in proc.stdout:
            yield line
    finally:
        # Always release the pipe and reap the child, even on early exit.
        proc.stdout.close()
        proc.wait()


def _read_table(path):
    """Parse the tab-separated table at *path*.

    Lines starting with '#' and blank lines are skipped.  The first column of
    every remaining line is the entry ID; all other columns are parsed as
    floats.

    Returns:
        (labels, rows): the IDs and the matching lists of floats, both in
        file order.

    Raises:
        ValueError: if a field is not a number or a row has a different
            number of values than the first data row.
    """
    labels = []
    rows = []
    for line in _iter_input_lines(path):
        if line.startswith('#'):
            continue
        fields = line.strip().split('\t')
        if fields == ['']:
            # Blank line: it would otherwise become an ID-less, empty row.
            continue
        values = [float(field) for field in fields[1:]]
        if rows and len(values) != len(rows[0]):
            raise ValueError('row %r has %d values, expected %d'
                             % (fields[0], len(values), len(rows[0])))
        labels.append(fields[0])
        rows.append(values)
        # Progress indicator for large inputs.
        if len(rows) % 10000 == 0:
            print(len(rows))
    return labels, rows


def run():
    """Cluster the rows of a table and write their correlation matrix.

    Command line: ``datafile linkage_method distance_metric outfile``.

    The rows of ``datafile`` are hierarchically clustered with
    ``scipy.cluster.hierarchy.linkage`` using the given method and metric.
    The pairwise Pearson correlation matrix of the rows is then written to
    ``outfile`` as a tab-separated table whose rows and columns follow the
    dendrogram leaf order.  The first output line is a '#' header listing the
    IDs in that order.

    Exits with status 1 after printing usage if fewer than four arguments
    are given.

    Raises:
        ValueError: if the input holds fewer than two data rows or is
            malformed.
    """
    # All four positional arguments are required (argv[0] is the script name).
    if len(sys.argv) < 5:
        print('usage: python %s datafile linkage_method distance_metric outfile' % sys.argv[0])
        print('\tNote: the input file can be .bz2 or gz')
        print('\tNote: the script assumes that the left most column contains entry IDs')
        sys.exit(1)

    datafile, method, metric, outfilename = sys.argv[1:5]

    labels, rows = _read_table(datafile)
    if len(rows) < 2:
        raise ValueError('need at least two data rows to cluster, found %d'
                         % len(rows))

    data = np.array(rows)
    tree = linkage(data, method=method, metric=metric)

    # Use leaf *indices* rather than labels so that rows sharing the same ID
    # still map to their own data.
    leaf_order = dendrogram(tree, no_plot=True)['leaves']
    ordered_labels = [labels[k] for k in leaf_order]

    # One vectorised call yields every pairwise row correlation.
    corr = np.corrcoef(data[leaf_order])

    with open(outfilename, 'w') as outfile:
        outfile.write('#\t' + '\t'.join(ordered_labels) + '\n')
        for label, corr_row in zip(ordered_labels, corr):
            cells = '\t'.join(str(value) for value in corr_row)
            outfile.write(label + '\t' + cells + '\n')

# Run the pipeline only when executed as a script, so that importing this
# module does not start clustering on the importer's sys.argv.
if __name__ == '__main__':
    run()
