##################################
#                                #
# Last modified 2019/10/09       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import os
import string
import math
from sets import Set
import numpy as np
import gzip
# import pickle
import numpy as np
import pandas as pd
import umap
# from cycler import cycler
# import urllib
# from MulticoreTSNE import MulticoreTSNE as TSNE
# import MulticoreTSNE as TSNE

def _iter_lines(path):
    """Yield the text lines of *path*, decompressing ``.bz2``/``.gz`` inputs.

    Compressed files are streamed through ``bzip2 -cd`` / ``zcat`` with an
    argument list (no shell), so unusual filenames cannot inject commands.

    Raises:
        IOError: if the decompression command exits with a non-zero status.
    """
    import subprocess

    if path.endswith('.bz2'):
        command = ['bzip2', '-cd', path]
    elif path.endswith('.gz'):
        command = ['zcat', path]
    else:
        with open(path) as handle:
            for line in handle:
                yield line
        return

    proc = subprocess.Popen(command, stdout=subprocess.PIPE,
                            universal_newlines=True)
    try:
        for line in proc.stdout:
            yield line
    finally:
        # Always reap the child, even if the consumer stops early.
        proc.stdout.close()
        status = proc.wait()
    if status != 0:
        raise IOError('%s exited with status %d' % (command[0], status))


def _read_matrix(path, separator):
    """Parse a labelled numeric matrix from *path*.

    The first non-blank line is the header; its fields after the first are the
    column labels.  Every later non-blank line that does not start with '#'
    holds a row label (ignored) followed by one value per column.

    Returns:
        (labels, rows): the column labels and a list of float rows.

    Raises:
        ValueError: if there is no header, no data rows, or a row's value
            count differs from the number of header labels.
    """
    labels = None
    rows = []
    for line_number, line in enumerate(_iter_lines(path), 1):
        text = line.strip()
        if not text:
            continue
        fields = text.split(separator)
        if labels is None:
            labels = fields[1:]
            continue
        if text.startswith('#'):
            continue
        values = [float(field) for field in fields[1:]]
        if len(values) != len(labels):
            raise ValueError('line %d has %d values but the header has %d labels'
                             % (line_number, len(values), len(labels)))
        rows.append(values)
    if labels is None or not rows:
        raise ValueError('no header or no data rows found in %s' % path)
    return labels, rows


def _quantile_normalize_rows(rows):
    """Quantile-normalize the columns of *rows*, then z-score every row.

    Rows with zero standard deviation become all zeros instead of NaN.
    """
    df = pd.DataFrame(np.array(rows, dtype=float))
    # Mean value at each rank across columns, then map every entry to the
    # mean of its rank (ties share the lowest rank via method='min').
    rank_mean = df.stack().groupby(df.rank(method='first').stack().astype(int)).mean()
    normalized = df.rank(method='min').stack().astype(int).map(rank_mean).unstack().values
    means = normalized.mean(axis=1, keepdims=True)
    stds = normalized.std(axis=1, keepdims=True)
    stds[stds == 0] = 1.0
    return (normalized - means) / stds


def _top_variable_rows(matrix, count):
    """Return the *count* rows of *matrix* with the largest variance, most variable first."""
    variances = matrix.var(axis=1)
    # A stable sort on the negated variances orders rows by decreasing
    # variance and keeps ties in file order, without ever comparing the
    # rows themselves.
    order = np.argsort(-variances, kind='mergesort')
    return matrix[order[:count]]


def run():
    """Embed the columns (cells) of a row-by-column matrix in 2-D with UMAP.

    Command line::

        python script.py input.csv(.gz/.bz2) outfile [-tsv] [-quantNorm] [-topVariablePeaks N]

    Options:
        -tsv                  split fields on tabs instead of commas.
        -quantNorm            quantile-normalize the columns, then z-score
                              each row before embedding.
        -topVariablePeaks N   keep only the N rows with the highest variance.

    Writes one ``label<TAB>x<TAB>y`` line per column to ``outfile``.
    Exits with status 1 after printing a usage message when arguments are
    missing or ``-topVariablePeaks`` lacks a positive integer.
    """
    if len(sys.argv) < 3:
        print('usage: python %s input.csv(.gz/.bz2) outfile [-tsv] [-quantNorm] [-topVariablePeaks N]' % sys.argv[0])
        print('\tinput format: one label field and then a matrix of values, genes in rows, cells in columns')
        print('\tthe script will write label, x, y (tab-separated) for each cell to outfile')
        sys.exit(1)

    input_path = sys.argv[1]
    outfilename = sys.argv[2]
    separator = '\t' if '-tsv' in sys.argv else ','

    doQN = '-quantNorm' in sys.argv
    if doQN:
        print('will quantile normalize data')

    topVP = None
    if '-topVariablePeaks' in sys.argv:
        flag_index = sys.argv.index('-topVariablePeaks')
        try:
            topVP = int(sys.argv[flag_index + 1])
        except (IndexError, ValueError):
            print('-topVariablePeaks requires an integer argument')
            sys.exit(1)
        if topVP < 1:
            print('-topVariablePeaks must be a positive integer')
            sys.exit(1)
        print('will only use the top %d variable peaks' % topVP)

    labels, rows = _read_matrix(input_path, separator)

    if doQN:
        X = _quantile_normalize_rows(rows)
    else:
        X = np.array(rows, dtype=float)

    # Use the selected subset (the original rebuilt the matrix from all rows).
    Y = _top_variable_rows(X, topVP) if topVP is not None else X
    print(len(Y))

    # Transpose so each column/cell becomes one sample row, matching the
    # column labels written next to each embedding row below.
    embedding = umap.UMAP().fit_transform(np.transpose(Y))
    print(embedding)

    # _read_matrix guarantees len(labels) == number of columns == rows here.
    with open(outfilename, 'w') as outfile:
        for label, (t1, t2) in zip(labels, embedding):
            outfile.write(str(label) + '\t' + str(t1) + '\t' + str(t2) + '\n')
   
# Only run the pipeline when executed as a script, not when imported.
if __name__ == '__main__':
    run()
