##################################
#                                #
# Last modified 2017/01/17       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string

def _parse_field_ids(spec):
    """Turn a comma-separated string of column indices into a list of ints."""
    return [int(field_id) for field_id in spec.split(',')]


def _load_records(path, trueFieldID, classFieldIDs, classThreshold, classFieldIDs2):
    """Read the tab-delimited input and build one record per data line.

    Lines starting with '#' and blank lines are skipped.  Each record is a
    tuple (classification, classScore, trueScore) where classScore is the
    sum of the columns listed in classFieldIDs and trueScore is the column
    trueFieldID.  When classThreshold is not None the classification is 1
    if classScore >= classThreshold and 0 otherwise; when it is None the
    classification is the sum of the integer columns in classFieldIDs2.
    """
    records = []
    # 'with' guarantees the input handle is closed even if a line fails to parse.
    with open(path) as linelist:
        for line in linelist:
            if line.startswith('#'):
                continue
            stripped = line.strip()
            if not stripped:
                # A blank (e.g. trailing) line carries no data.
                continue
            fields = stripped.split('\t')
            trueScore = float(fields[trueFieldID])
            classScore = sum(float(fields[cID]) for cID in classFieldIDs)
            if classThreshold is not None:
                if classScore >= classThreshold:
                    classification = 1
                else:
                    classification = 0
            else:
                classification = sum(int(fields[cID2]) for cID2 in classFieldIDs2)
            records.append((classification, classScore, trueScore))
    return records


def _cumulative_counts(records, threshold):
    """Walk the ranked records and return the running (FP, TP) counts.

    A record counts as a true positive when its trueScore >= threshold and
    as a false positive otherwise; one (FP, TP) pair is emitted per record.
    """
    counts = []
    FPx = 0
    TPy = 0
    for (_classification, _classScore, trueScore) in records:
        if trueScore >= threshold:
            TPy += 1
        else:
            FPx += 1
        counts.append((FPx, TPy))
    return counts


def run():
    """Compute ROC points and the AUROC from the command-line arguments.

    Expected sys.argv layout:
        1  input file (tab-delimited, '#' lines ignored)
        2  column index of the true score
        3  threshold; records with true score >= threshold are positives
        4  comma-separated column indices whose sum is the class score
        5  'threshold:VALUE' (classification = class score >= VALUE) or
           'class:ID[,ID...]' (classification = sum of those int columns)
        6  output file name

    Records are ranked in descending (classification, classScore, trueScore)
    order.  The output file receives a '#FPR<TAB>TPR' header, the origin
    '0<TAB>0', one FPR/TPR line per record and a final '#AUROC' line, which
    is also printed to stdout.

    Exits with status 1 on a malformed command line or when the data holds
    no positives or no negatives (the rates would be undefined).
    """
    # sys.argv[6] (the output file) is read below, so 7 entries are required.
    if len(sys.argv) < 7:
        print('usage: python %s input trueScoreFieldID threshold classificationFieldID "threshold":value|"class":fieldID outfilename' % sys.argv[0])
        print('\tNote: the classificationFieldID and class:fieldID parameters can include multiple fileds; then the sum of them will be used')
        sys.exit(1)

    inputfilename = sys.argv[1]
    trueFieldID = int(sys.argv[2])
    threshold = float(sys.argv[3])
    classFieldIDs = _parse_field_ids(sys.argv[4])
    outfilename = sys.argv[6]

    # Exactly one of classThreshold / classFieldIDs2 is set, chosen by the
    # prefix of the fifth argument; anything else is rejected up front
    # rather than failing later with a NameError.
    modeParts = sys.argv[5].split(':')
    classThreshold = None
    classFieldIDs2 = None
    if len(modeParts) >= 2 and modeParts[0] == 'threshold':
        classThreshold = float(modeParts[1])
    elif len(modeParts) >= 2 and modeParts[0] == 'class':
        classFieldIDs2 = _parse_field_ids(modeParts[1])
    else:
        print('error: expected "threshold":value or "class":fieldID, got %r' % sys.argv[5])
        sys.exit(1)

    DataList = _load_records(inputfilename, trueFieldID, classFieldIDs,
                             classThreshold, classFieldIDs2)

    # Rank best-first.
    # NOTE(review): tuples compare lexicographically, so ties on
    # (classification, classScore) are broken by trueScore, placing true
    # positives ahead of negatives within a tie -- confirm this is intended.
    DataList.sort(reverse=True)

    ROClist = _cumulative_counts(DataList, threshold)

    # TPR = TP/P = TP/(TP + FN)
    # FPR = FP/N = FP/(FP + TN)
    # The last cumulative pair holds the totals of negatives and positives.
    if ROClist:
        N = float(ROClist[-1][0])
        P = float(ROClist[-1][1])
    else:
        N = 0.0
        P = 0.0
    if P == 0 or N == 0:
        print('error: ROC is undefined with %d positives and %d negatives' % (P, N))
        sys.exit(1)

    AUROC = 0.0
    previousFPx = 0
    with open(outfilename, 'w') as outfile:
        outfile.write('#FPR\tTPR' + '\n')
        outfile.write('0\t0\n')
        for (FPx, TPy) in ROClist:
            FPR = FPx / N
            TPR = TPy / P
            # Each step to the right (one more false positive) adds a
            # rectangle of width 1/N and height equal to the current TPR.
            if FPx > previousFPx:
                AUROC += (1 / N) * TPR
            previousFPx = FPx
            outfile.write(str(FPR) + '\t' + str(TPR) + '\n')

        outline = '#AUROC\t' + str(AUROC)
        print(outline)
        outfile.write(outline + '\n')

# Run the analysis only when executed as a script, so that importing this
# module does not trigger a run against the importer's command line.
if __name__ == '__main__':
    run()

