##################################
#                                #
# Last modified 2019/08/02       # 
#                                #
# Georgi Marinov                 #
#                                #
##################################

import sys
import string
import math
import numpy
import pandas as pd
from sklearn.decomposition import PCA

def _parse_field_list(spec):
    # Parse a column specification: either comma-separated 0-based indices
    # ("0,3,5") or an inclusive 0-based range ("2:7"). Returns sorted ints.
    if ':' in spec:
        start, end = spec.split(':')
        fields = list(range(int(start), int(end) + 1))
    else:
        fields = [int(f) for f in spec.split(',')]
    fields.sort()
    return fields


def _zscore_rows(rows):
    # Z-score each row: subtract the row mean, divide by the row stdev.
    # NOTE(review): a constant row has stdev 0 and produces NaN/inf, same
    # as the original code -- PCA will then fail loudly; confirm whether
    # such rows should be filtered upstream.
    normed = []
    for row in rows:
        X = numpy.asarray(row, dtype=float)
        normed.append((X - numpy.mean(X)) / numpy.std(X))
    return normed


def run():
    """Row-normalize selected columns of a tab-separated table and run PCA.

    Command line:
        input labelfields valuefields outputfileprefix
        [-minMax Value] [-maxMax Value] [-quantNorm] [-normToMean]

    Header lines (starting with '#' or 'tracking_id') supply the dataset
    labels taken from the value columns.  Data rows are optionally
    filtered by their per-row maximum, optionally quantile-normalized
    across datasets, z-scored per row, and fed to PCA.

    Writes:
        <outprefix>.PCs                 dataset, PC1 loading, PC2 loading
        <outprefix>.variance_explained  one eigenvalue per line
    """
    # Fix: the original tested < 4 but reads sys.argv[4] (outprefix),
    # so exactly four arguments crashed with IndexError.
    if len(sys.argv) < 5:
        print('usage: python %s input labelfields valuefields outputfileprefix [-minMax Value] [-maxMax Value] [-quantNorm] [-normToMean]' % sys.argv[0])
        print('\tvaluefields format: either comma separated, or start:end (including start and end, 0-based)')
        print('\tthe -minMax and -maxMax options will filter out all lines the maximum value of which (within the specified fields) does not match the required values')
        sys.exit(1)

    infile = sys.argv[1]  # renamed from 'input' to avoid shadowing the builtin
    # labelFields is parsed (validating the argument) but the original
    # only ever built an unused label string from it, which is dropped here.
    labelFields = _parse_field_list(sys.argv[2])
    valueFields = _parse_field_list(sys.argv[3])
    outprefix = sys.argv[4]

    doQN = '-quantNorm' in sys.argv
    if doQN:
        print('will quantile normalize data')

    doNTM = '-normToMean' in sys.argv
    if doNTM:
        print('will normalize to the mean instead of the stdv')
        # NOTE(review): this flag is accepted and announced but the code
        # below always divides by the stdev -- TODO confirm whether
        # -normToMean was ever implemented.

    doMin = '-minMax' in sys.argv
    if doMin:
        minValue = float(sys.argv[sys.argv.index('-minMax') + 1])
    doMax = '-maxMax' in sys.argv
    if doMax:
        maxValue = float(sys.argv[sys.argv.index('-maxMax') + 1])

    maxField = max(valueFields)
    DataSets = []   # dataset labels from the header line, in field order
    rows = []       # raw float values per retained data row

    # Single parse pass shared by the QN and non-QN paths (the original
    # duplicated this loop in both branches).
    with open(infile) as stream:
        for line in stream:
            fields = line.replace('\x00', '').strip().split('\t')
            if line.startswith('#') or line.startswith('tracking_id'):
                # Fix: need len(fields) > maxField to index fields[maxField];
                # the original '< maxField' allowed an IndexError.
                if len(fields) <= maxField:
                    continue
                for ID in valueFields:
                    DataSets.append(fields[ID])
                continue
            values = [float(fields[ID]) for ID in valueFields]
            if doMin and max(values) < minValue:
                continue
            if doMax and max(values) > maxValue:
                continue
            rows.append(values)

    if doQN:
        # Classic quantile normalization: replace each value by the mean
        # of the values sharing its rank across all columns (datasets).
        df = pd.DataFrame(rows)
        rank_mean = df.stack().groupby(
            df.rank(method='first').stack().astype(int)).mean()
        normalized = df.rank(method='min').stack().astype(int).map(rank_mean).unstack()
        # Fix: DataFrame.as_matrix() was removed in pandas 1.0;
        # to_numpy() is the supported replacement.
        data = _zscore_rows(normalized.to_numpy())
    else:
        data = _zscore_rows(rows)

    X = numpy.array(data)

    # One component per dataset; only the first two loadings are reported.
    pca = PCA(n_components=len(valueFields))
    pca.fit(X)

    with open(outprefix + '.PCs', 'w') as outfile:
        outfile.write('#dataset\tPC1\tPC2\n')
        for i in range(len(valueFields)):
            outfile.write('%s\t%s\t%s\n' % (
                DataSets[i], pca.components_[0, i], pca.components_[1, i]))

    with open(outprefix + '.variance_explained', 'w') as outfile:
        for i in range(len(valueFields)):
            outfile.write(str(pca.explained_variance_[i]) + '\n')
   
# Standard script-entry guard so importing this module does not trigger
# the command-line tool (behavior when executed directly is unchanged).
if __name__ == '__main__':
    run()
