##################################
#                                #
# Last modified 2018/01/04       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import os
import string
import sys
import phate
import numpy as np
import pandas as pd
import sklearn.manifold
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def _write_embedding(path, labels, coords, n_dims):
    """Write one tab-separated line per cell: label, then its coordinates.

    path:   name of the output file (created or truncated).
    labels: cell names, one per row of coords.
    coords: 2-D array whose rows are the embedded cells.
    n_dims: number of leading coordinates of each row to write.
    """
    # The context manager closes the file even if a write fails part-way.
    with open(path, 'w') as outfile:
        for label, row in zip(labels, coords):
            fields = [label] + [str(value) for value in row[:n_dims]]
            outfile.write('\t'.join(fields) + '\n')


def run():
    """Embed a cells-by-genes count matrix with PCA, t-SNE and PHATE.

    Command line arguments (read from sys.argv):
        1. matrix    -- CSV file (cells in rows, genes in columns); the
                        first column is used as the cell label.
        2. outprefix -- prefix for the output files.

    Output files, one tab-separated line per cell, label first:
        <outprefix>.pca          first two PCA coordinates
        <outprefix>.tsne         two t-SNE coordinates
        <outprefix>.phate_cmds   three PHATE coordinates (mds='classic')
        <outprefix>.phate_mmds   three PHATE coordinates (mds='metric')
        <outprefix>.phate_nmmds  three PHATE coordinates (mds='nonmetric')

    Prints a usage message and exits with status 1 when fewer than two
    arguments are supplied.
    """
    if len(sys.argv) < 3:
        print('usage: python3 %s matrix.csv.gz outprefix' % sys.argv[0])
        print('\tcsv.gz format: cells in rows, genes in columns, counts')
        print('\t\tlabel <tab> genes.tsv <tab> matrix.mtx')
        sys.exit(1)

    matrix = sys.argv[1]
    outprefix = sys.argv[2]

    bmmsc = pd.read_csv(matrix, index_col=0)

    # Take the cell labels from the DataFrame index so they always line up
    # with the rows that get embedded.  Re-reading the file through `zcat`
    # needed an external tool, passed the file name through a shell, and
    # counted the CSV header as a cell unless it started with '#', which
    # shifted every label by one row and overran the arrays.
    # str() is applied because pandas may parse purely numeric labels as
    # numbers.
    cells = [str(label) for label in bmmsc.index]

    # library_size_normalize performs L1 normalization on each cell
    bmmsc_norm = phate.preprocessing.library_size_normalize(bmmsc)
    bmmsc_reduced = phate.preprocessing.pca_reduce(bmmsc_norm, n_components=20)
    # NOTE(review): assumes the preprocessing keeps the row order of the
    # input, exactly as the per-row output below has always assumed.
    pca = bmmsc_reduced[:, 0:2]  # first two PCA dimensions
    tsne = sklearn.manifold.TSNE().fit_transform(bmmsc_reduced)

    # Same PHATE settings three times; only the MDS variant differs.
    # Everything is computed before any file is written, in the order
    # classic, metric, nonmetric.
    phate_runs = []
    for suffix, mds in (('.phate_cmds', 'classic'),
                        ('.phate_mmds', 'metric'),
                        ('.phate_nmmds', 'nonmetric')):
        phate_operator = phate.PHATE(n_components=3, t=40, a=10, k=4,
                                     mds=mds, mds_dist='euclidean')
        phate_runs.append((suffix, phate_operator.fit_transform(bmmsc_reduced)))

    _write_embedding(outprefix + '.pca', cells, pca, 2)
    _write_embedding(outprefix + '.tsne', cells, tsne, 2)
    for suffix, embedding in phate_runs:
        _write_embedding(outprefix + suffix, cells, embedding, 3)

    print('finished all steps')

# Run the pipeline only when this file is executed as a script, so that
# importing the module does not start the analysis or exit the interpreter.
if __name__ == '__main__':
    run()