##################################
#                                #
# Last modified 2018/06/09       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import os
import string
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import phate
import sklearn.decomposition
import sklearn.manifold
from mpl_toolkits.mplot3d import Axes3D

def _write_embedding(path, cells, coords, split_ids):
    """Write one tab-separated line per cell: label column(s), then coordinates.

    path      -- output file name
    cells     -- sequence of cell label strings, one per embedding row
    coords    -- 2-D array-like of embedding coordinates (rows match cells)
    split_ids -- if true, every '-' in a label is replaced by a tab

    Raises ValueError if the number of labels and embedding rows differ.
    """
    if len(cells) != len(coords):
        raise ValueError(
            'number of cell labels (%d) does not match number of embedded '
            'rows (%d)' % (len(cells), len(coords)))
    with open(path, 'w') as outfile:
        for cell, row in zip(cells, coords):
            label = cell.replace('-', '\t') if split_ids else cell
            fields = [label] + [str(value) for value in row]
            outfile.write('\t'.join(fields) + '\n')


def _timed_fit_transform(name, operator, data):
    """Run operator.fit_transform(data), report the elapsed time, return the result."""
    start = time.time()
    coords = operator.fit_transform(data)
    end = time.time()
    print("Embedded {} in {:.2f} seconds.".format(name, end - start))
    return coords


def run():
    """Embed a cells-by-genes count matrix with PHATE, PCA and t-SNE.

    Command line: matrix.csv.gz outprefix [-FBC-split]

    The matrix is read with pandas (first column used as the row index),
    passed through phate.preprocessing.library_size_normalize and then
    square-root transformed.  Five embeddings are computed and written as
    tab-separated text, one line per cell:

        outprefix.phate_cmds    PHATE, mds='classic',   3 components
        outprefix.phate_mmds    PHATE, mds='metric',    3 components
        outprefix.phate_nmmds   PHATE, mds='nonmetric', 3 components
        outprefix.pca           PCA, 3 components
        outprefix.tsne          t-SNE, 2 components

    With -FBC-split, every '-' in a cell label is replaced by a tab in
    the output files.

    Exits with status 1 after printing usage if fewer than two positional
    arguments are given.
    """
    # Both the matrix and the output prefix are required, i.e. argv must
    # have at least three entries.
    if len(sys.argv) < 3:
        print('usage: python3 %s matrix.csv.gz outprefix [-FBC-split]' % sys.argv[0])
        print('\tcsv.gz format: cells in rows, genes in columns, counts')
        print('\t\tlabel <tab> genes.tsv <tab> matrix.mtx')
        sys.exit(1)

    matrix = sys.argv[1]
    outprefix = sys.argv[2]

    doFBCSplit = '-FBC-split' in sys.argv
    if doFBCSplit:
        print('will split cell IDs from barcodes')

    print('inputting cells')

    bmmsc = pd.read_csv(matrix, index_col=0)

    print('finished inputting cells')

    # Take the labels from the DataFrame index so they are aligned with
    # the rows that get embedded, instead of re-reading the file through
    # a shell pipe.
    # NOTE(review): purely numeric IDs are parsed as numbers by pandas, so
    # formatting such as leading zeros is not preserved -- confirm that
    # cell IDs are never purely numeric.
    cells = [str(cell) for cell in bmmsc.index]

    print('finished parsing cells')

    bmmsc_norm = phate.preprocessing.library_size_normalize(bmmsc)
    bmmsc_norm = np.sqrt(bmmsc_norm)

    # (mds mode, name used in the timing message, output file suffix)
    phate_variants = (
        ('classic', 'CMDS', '.phate_cmds'),
        ('metric', 'MMDS', '.phate_mmds'),
        ('nonmetric', 'NMMDS', '.phate_nmmds'),
    )
    for mds_mode, name, suffix in phate_variants:
        phate_operator = phate.PHATE(n_components=3, a=15, k=4, mds=mds_mode, mds_dist='euclidean')
        coords = _timed_fit_transform(name + ' PHATE', phate_operator, bmmsc_norm)
        _write_embedding(outprefix + suffix, cells, coords, doFBCSplit)

    pca_operator = sklearn.decomposition.PCA(n_components=3)
    Y_pca = _timed_fit_transform('PCA', pca_operator, np.array(bmmsc_norm))
    _write_embedding(outprefix + '.pca', cells, Y_pca, doFBCSplit)

    # Use the explicitly configured operator (it was previously built and
    # then ignored in favour of a default-constructed TSNE()).
    tsne_operator = sklearn.manifold.TSNE(n_components=2)
    Y_tsne = _timed_fit_transform('t-SNE', tsne_operator, bmmsc_norm)
    _write_embedding(outprefix + '.tsne', cells, Y_tsne, doFBCSplit)

    print('finished all steps')

# Run the pipeline only when this file is executed as a script, so that
# importing the module does not trigger argument parsing or any work.
if __name__ == '__main__':
    run()