##################################
#                                #
# Last modified 11/22/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from sets import Set

def _parse_id_fields(spec):
    """Return the column indices described by *spec*.

    *spec* is either an inclusive ``start-end`` range or a comma-separated
    list of integers.
    """
    if '-' in spec:
        bounds = spec.split('-')
        return list(range(int(bounds[0]), int(bounds[1]) + 1))
    return [int(field) for field in spec.split(',')]


def _parse_thresholds(spec):
    """Return the distinct comma-separated thresholds in *spec*, ascending."""
    return sorted(set(float(t) for t in spec.split(',')))


def run():
    """Count rows in which exactly one selected column exceeds a threshold.

    Command-line arguments (read from ``sys.argv``):
        1. datafilename -- tab-separated input; a line starting with '#'
           is treated as the header that supplies the column labels.
        2. IDs          -- column indices, comma separated or ``start-end``
           (end included).
        3. thresholds   -- comma-separated numbers; duplicates are dropped.
        4. outfilename  -- tab-separated output: a header line with the
           thresholds, then one line per column label with, for each
           threshold, the number of data rows in which that column was the
           only selected column with a value strictly above the threshold.

    Exits with status 1 after printing usage when arguments are missing.
    Raises KeyError if a data row is met before any header line, and
    ValueError/IndexError on malformed numeric fields or short rows.
    """
    # Four positional arguments are required in addition to the script name,
    # so sys.argv must hold at least five entries.
    if len(sys.argv) < 5:
        print('usage: python %s  datafilename IDs thresholds outfilename' % sys.argv[0])
        print('format of IDs: either comma separated or start-end (including end)')
        print('format of thresholds: comma-separated')
        print('presence of a header line assumed')
        sys.exit(1)

    datafilename = sys.argv[1]
    IDfields = _parse_id_fields(sys.argv[2])
    thresholdList = _parse_thresholds(sys.argv[3])
    outfilename = sys.argv[4]

    print(thresholdList)

    LabelDict = {}   # column index -> label taken from the header line
    OutputDict = {}  # label -> {threshold: number of qualifying rows}

    # Rows are tallied as they are read, so the whole file never has to be
    # held in memory.
    with open(datafilename) as lineslist:
        for i, line in enumerate(lineslist, 1):
            if i % 100000 == 0:
                print(i)
            if line[0] == '#':
                fields = line.strip().split('\t')
                for ID in IDfields:
                    LabelDict[ID] = fields[ID]
                    if fields[ID] not in OutputDict:
                        OutputDict[fields[ID]] = dict.fromkeys(thresholdList, 0)
                print(' '.join('%s %s' % (ID, LabelDict[ID]) for ID in IDfields))
                continue
            if not line.strip():
                # Blank lines (e.g. at the end of the file) carry no data.
                continue
            fields = line.strip().split('\t')
            values = {}
            for ID in IDfields:
                # Columns sharing a label collapse to the last such column.
                values[LabelDict[ID]] = float(fields[ID])
            for t in thresholdList:
                bigger = [label for label in values if values[label] > t]
                if len(bigger) == 1:
                    OutputDict[bigger[0]][t] += 1

    # Report each distinct label once, in the order the columns were given.
    labels = []
    for ID in IDfields:
        if ID in LabelDict and LabelDict[ID] not in labels:
            labels.append(LabelDict[ID])

    # The output file is opened only after the input was read successfully.
    with open(outfilename, 'w') as outfile:
        outfile.write('#' + ''.join('\t' + str(t) for t in thresholdList) + '\n')
        for label in labels:
            counts = OutputDict[label]
            outfile.write(label + ''.join('\t' + str(counts[t]) for t in thresholdList) + '\n')
        
# Script entry point: run() is invoked unconditionally at module level,
# so it also executes if this file is imported rather than run directly.
run()

