##################################
#                                #
# Last modified 04/12/2013       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import math
from sets import Set

def _find_bin(bin_edges, value):
    """Return the largest edge in sorted *bin_edges* that is <= *value*.

    Returns None when *value* is below every edge or is NaN, i.e. when it
    does not belong to any bin.
    """
    for edge in reversed(bin_edges):
        if value >= edge:
            return edge
    return None


def _format_score(score):
    """Render a fraction in [0, 1] truncated (not rounded) to two decimals.

    Keeps the historical output format ('0.5', '0.66', '1.0').  Very small
    fractions are stringified in exponent form (e.g. '1e-05'); those cannot
    be split on '.', and truncate to zero anyway, so they become '0.00'.
    """
    text = str(score)
    if 'e' in text:
        return '0.00'
    whole, _, fraction = text.partition('.')
    return whole + '.' + fraction[0:2]


def run():
    """Summarise, per read-number column and per expression bin, the fraction
    of genes whose value is within ``threshold`` of the final-column value.

    Command line: ``inputfilename parameters threshold outputfilename``.

    * The input is tab-separated.  A line starting with '#' is the header;
      its fields after the first name the value columns.  Every other line
      is ``gene value1 value2 ... finalValue``.
    * ``parameters`` is a comma-separated list of bin lower bounds; 0.0 is
      always added as an extra lower bound.  A gene is assigned to the bin
      with the largest lower bound that is <= its final value.
    * A value counts as a hit when
      ``|value - final| / (final + 0.001) <= threshold``.

    The output file holds a '#Group' line (bin lower bounds), a '#Number'
    line (genes per bin), then one line per column name (sorted as strings)
    with the hit fraction of each bin, truncated to two decimals.

    Rows that cannot be parsed as numbers, or whose final value falls in no
    bin (below the lowest bound, or NaN), are reported on stdout and skipped.
    Row values are paired with header columns by position; values beyond the
    number of header columns are ignored.

    Raises:
        ValueError: if ``threshold`` or a bin bound is not a number, or if a
            data row appears before any '#' header line.
    """
    # Four user arguments are required (sys.argv[4] is read below).
    if len(sys.argv) < 5:
        print('usage: python %s inputfilename parameters threshold outputfilename' % sys.argv[0])
        print('	input file format: ')
        print('	#Name	ReadNumber1	ReadNumber2	.... ReadNumber3')
        print('	gene	value1	value2	..... final value ')
        print('	parameters format: RPKM1,RPKM2,RPKM3...RPKMN (genes with values less than the first one will be omitted)')
        print('	threshold: maximum % difference from final value, for example 0.05)')

        sys.exit(1)

    inputfilename = sys.argv[1]
    parameters = sys.argv[2]
    threshold = float(sys.argv[3])
    outputfilename = sys.argv[4]

    # Unique, sorted bin lower bounds; 0.0 is always present.
    ParamList = sorted(set([0.0] + [float(p) for p in parameters.split(',')]))
    # Genes per bin.  Every bin (including 0.0) starts at zero so the
    # '#Number' line can always be written.
    NomoDict = dict((p, 0) for p in ParamList)

    print(ParamList)

    column_names = None
    # StatsDict[column name][bin] == [misses, hits]
    StatsDict = {}

    with open(inputfilename) as infile:
        for line in infile:
            fields = line.strip().split('\t')
            if line.startswith('#'):
                column_names = fields[1:]
                StatsDict = {}
                for RN in column_names:
                    StatsDict[RN] = dict((p, [0, 0]) for p in ParamList)
                continue

            try:
                values = [float(v) for v in fields[1:]]
                finalValue = values[-1]
            except (ValueError, IndexError):
                print('skipping %s' % (fields,))
                continue

            if column_names is None:
                raise ValueError(
                    "data row found before the '#' header line in %s" % inputfilename)

            param = _find_bin(ParamList, finalValue)
            if param is None:
                # Negative or NaN final value: belongs to no bin.
                print('skipping %s' % (fields,))
                continue

            NomoDict[param] += 1
            # Pair values with columns by position; looking a value up with
            # list.index() would credit the wrong column for repeated values.
            for RN, value in zip(column_names, values):
                within = math.fabs(value - finalValue) / (finalValue + 0.001) <= threshold
                StatsDict[RN][param][1 if within else 0] += 1

    print(NomoDict)
    print(ParamList)

    with open(outputfilename, 'w') as outfile:
        outfile.write('\t'.join(['#Group'] + [str(P) for P in ParamList]) + '\n')
        outfile.write('\t'.join(['#Number'] + [str(NomoDict[P]) for P in ParamList]) + '\n')

        for RN in sorted(StatsDict):
            print(RN)
            cells = [RN]
            for P in ParamList:
                misses, hits = StatsDict[RN][P]
                total = hits + misses
                # An empty bin is reported as 0.0 rather than dividing by zero.
                score = hits / float(total) if total else 0.0
                cells.append(_format_score(score))
            outfile.write('\t'.join(cells) + '\n')

# Script entry point: run() is called unconditionally at module level, so it
# executes on import as well as when the file is run directly.
run()

