##################################
#                                #
# Last modified 2018/01/03       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
import numpy

def _parse_value_fields(spec):
    """Return the sorted 0-based column indices described by ``spec``.

    ``spec`` is either an inclusive ``start:end`` range (e.g. ``"2:5"``)
    or a comma-separated list of indices (e.g. ``"2,3,7"``).
    """
    if ':' in spec:
        bounds = spec.split(':')
        start = int(bounds[0])
        end = int(bounds[1])
        return list(range(start, end + 1))
    return sorted(int(f) for f in spec.split(','))


def _quantile_normalize(rows):
    """Quantile-normalize the columns of ``rows``.

    ``rows`` is a list of equal-length lists of floats.  Each column is
    sorted independently; the mean across columns at each sorted position
    becomes the reference value for that rank.  Every input value is then
    replaced by the reference value of its rank within its own column.
    Tied values all receive the rank of their first occurrence in the
    sorted column.

    Returns a new list of rows (lists of numpy floats); ``rows`` itself is
    left untouched.
    """
    if not rows:
        return []

    nColumns = len(rows[0])
    sortedColumns = [sorted(row[j] for row in rows) for j in range(nColumns)]

    # Reference distribution: mean of the i-th smallest value of every column.
    rankMeans = []
    for i in range(len(rows)):
        rankMeans.append(numpy.mean(numpy.array([col[i] for col in sortedColumns])))

    # value -> index of its first occurrence in the sorted column.  This gives
    # the same rank as list.index() without a linear scan per lookup.
    firstRank = []
    for col in sortedColumns:
        lookup = {}
        for rank, value in enumerate(col):
            lookup.setdefault(value, rank)
        firstRank.append(lookup)

    return [[rankMeans[firstRank[j][value]] for j, value in enumerate(row)]
            for row in rows]


def run():
    """Quantile-normalize selected columns of a tab-separated file.

    Command line (read from ``sys.argv``)::

        input labelfields valuefields outputfilename
            [-minMax Value] [-maxMax Value] [-normToMean]

    * ``labelfields``: comma-separated 0-based columns copied to the output
      unchanged.
    * ``valuefields``: comma-separated 0-based columns, or an inclusive
      ``start:end`` range, holding the numbers to normalize.
    * ``-minMax`` / ``-maxMax``: drop rows whose maximum over the value
      columns is below / above the given threshold.
    * ``-normToMean``: after normalization, write ``(v - m) / m`` for every
      value ``v`` of a row, where ``m`` is the mean of that row's normalized
      values.

    Lines starting with ``#`` or ``tracking_id`` are treated as headers: the
    selected columns are written out immediately, prefixed with ``#``.
    Header lines too short to contain every selected column and blank lines
    are skipped.  Exits with status 1 after printing usage when fewer than
    the four positional arguments are supplied.
    """
    # Four positional arguments are required, so argv needs at least 5 entries.
    if len(sys.argv) < 5:
        print('usage: python %s input labelfields valuefields outputfilename [-minMax Value] [-maxMax Value] [-normToMean]' % sys.argv[0])
        print('\tvaluefields format: either comma separated, or start:end (including start and end, 0-based)')
        print('\tthe -minMax and -maxMax options will filter out all lines the maximum value of which (within the specified fields) does not match the required values')
        sys.exit(1)

    inputfilename = sys.argv[1]
    outfilename = sys.argv[4]

    labelFields = sorted(int(f) for f in sys.argv[2].split(','))
    print(labelFields)

    doNTM = '-normToMean' in sys.argv
    if doNTM:
        print('will normalize to the raw mean after the quantile normalization')

    valueFields = _parse_value_fields(sys.argv[3])
    print(valueFields)

    doMin = '-minMax' in sys.argv
    if doMin:
        minValue = float(sys.argv[sys.argv.index('-minMax') + 1])
    doMax = '-maxMax' in sys.argv
    if doMax:
        maxValue = float(sys.argv[sys.argv.index('-maxMax') + 1])

    selectedFields = labelFields + valueFields
    lastField = max(selectedFields)

    DataLabels = []
    DataValues = []

    with open(outfilename, 'w') as outfile:
        with open(inputfilename) as infile:
            for line in infile:
                fields = line.replace('\x00', '').strip().split('\t')
                if line.startswith('#') or line.startswith('tracking_id'):
                    # Free-form comment lines may not have every selected
                    # column; only echo headers that do.
                    if len(fields) <= lastField:
                        continue
                    header = '#' + '\t'.join(fields[ID] for ID in selectedFields)
                    outfile.write(header.strip() + '\n')
                    continue
                if fields == ['']:
                    # blank line
                    continue
                values = [float(fields[ID]) for ID in valueFields]
                if doMin and max(values) < minValue:
                    continue
                if doMax and max(values) > maxValue:
                    continue
                DataValues.append(values)
                DataLabels.append([fields[ID] for ID in labelFields])

        for labels, QVals in zip(DataLabels, _quantile_normalize(DataValues)):
            if doNTM:
                arrMean = numpy.mean(numpy.array(QVals))
                outValues = [(QV - arrMean) / arrMean for QV in QVals]
            else:
                outValues = QVals
            outline = '\t'.join(labels + [str(v) for v in outValues])
            outfile.write(outline.strip() + '\n')
   
# Run the command-line pipeline only when executed as a script, so that
# importing this module does not trigger argument parsing or file I/O.
if __name__ == '__main__':
    run()
