##################################
#                                #
# Last modified 2017/08/23       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import numpy
import scipy.stats

def _parse_label_fields(spec):
    """Parse a comma-separated list of 0-based column indices; return it sorted."""
    return sorted(int(f) for f in spec.split(','))


def _parse_value_fields(spec):
    """Parse the value-column specification into a sorted list of indices.

    *spec* is either ``start:end`` (0-based, inclusive on both ends) or a
    comma-separated list of 0-based column indices.
    """
    if ':' in spec:
        bounds = spec.split(':')
        start = int(bounds[0])
        end = int(bounds[1])
        return list(range(start, end + 1))
    return sorted(int(f) for f in spec.split(','))


def _read_table(lines, labelFields, valueFields):
    """Map each row's label tuple to the list of its values parsed as floats.

    Lines starting with '#' or 'tracking_id' (presumably header rows) are
    skipped, as are lines that are empty once NUL bytes and surrounding
    whitespace are removed. If a label occurs on several lines, the last
    occurrence wins.
    """
    DataDict = {}
    for line in lines:
        if line.startswith('#') or line.startswith('tracking_id'):
            continue
        cleaned = line.replace('\x00', '').strip()
        if not cleaned:
            # Blank lines (e.g. a trailing newline) have no columns to index.
            continue
        fields = cleaned.split('\t')
        label = tuple(fields[ID] for ID in labelFields)
        DataDict[label] = [float(fields[ID]) for ID in valueFields]
    return DataDict


def _rank_correlations(DataDict, target, doSpearman):
    """Correlate every row with the target row.

    Returns a list of (CC, label) pairs sorted by descending CC, with ties
    broken by descending label. NaN correlations (e.g. from a constant row)
    are placed at the end.
    """
    targetValues = DataDict[target]
    CorrelationList = []
    for label, values in DataDict.items():
        if doSpearman:
            CC = scipy.stats.spearmanr(targetValues, values)[0]
        else:
            CC = numpy.corrcoef(targetValues, values)[0, 1]
        CorrelationList.append((CC, label))

    def _sortKey(item):
        # NaN compares False against everything, which makes a plain sort on
        # (CC, label) inconsistent; rank NaNs below every real value instead.
        CC, label = item
        if numpy.isnan(CC):
            return (0, 0.0, label)
        return (1, CC, label)

    CorrelationList.sort(key=_sortKey, reverse=True)
    return CorrelationList


def _write_correlations(outfilename, CorrelationList):
    """Write one tab-separated line per pair: the CC followed by the label fields."""
    with open(outfilename, 'w') as outfile:
        for CC, label in CorrelationList:
            outfile.write('\t'.join((str(CC),) + label) + '\n')


def run():
    """Command-line entry point: rank table rows by correlation to a target row.

    Positional arguments (from sys.argv):
        input           tab-separated file, or '-' for stdin
        labelfields     comma-separated 0-based columns forming each row's label
        valuefields     'start:end' (inclusive) or comma-separated 0-based columns
        target          comma-separated label values identifying the target row
        outputfilename  destination file
    Flags:
        -spearman       use Spearman rank correlation instead of Pearson
        -average        accepted, but currently has no effect

    Writes one line per row, "CC<TAB>label fields...", sorted by descending
    correlation; NaN correlations come last. Exits with status 1 on a usage
    error or when the target row is absent from the input.
    """
    # Five positional arguments are required (sys.argv[1]..sys.argv[5]).
    if len(sys.argv) < 6:
        print('usage: python %s input labelfields valuefields target outputfilename [-spearman]' % sys.argv[0])
        print('\tvaluefields format: either comma separated, or start:end (including start and end, 0-based)')
        print('\ttarget element IDs - comma separated')
        print('\tuse - for stdin')
        sys.exit(1)

    inputname = sys.argv[1]
    outfilename = sys.argv[5]
    doSpearman = '-spearman' in sys.argv

    labelFields = _parse_label_fields(sys.argv[2])
    print(labelFields)

    if '-average' in sys.argv:
        # NOTE(review): -average is accepted for backward compatibility, but no
        # averaging is implemented anywhere in this script.
        print('will average over instances')

    valueFields = _parse_value_fields(sys.argv[3])
    print(valueFields)

    target = tuple(sys.argv[4].split(','))

    if inputname == '-':
        DataDict = _read_table(sys.stdin, labelFields, valueFields)
    else:
        with open(inputname) as infile:
            DataDict = _read_table(infile, labelFields, valueFields)

    if target not in DataDict:
        sys.stderr.write('target %s not found in %s\n' % (','.join(target), inputname))
        sys.exit(1)

    # The output file is opened only once everything above has succeeded, so a
    # failed run does not truncate an existing output.
    _write_correlations(outfilename, _rank_correlations(DataDict, target, doSpearman))
   
# Run only when executed as a script, so the module can be imported without
# side effects.
if __name__ == '__main__':
    run()
