##################################
#                                #
# Last modified 2018/02/01       #
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
import os
import random	
from sets import Set
import time

def run():

    if len(sys.argv) < 2:
        print 'usage: python %s list_of_files outfilename' % sys.argv[0]
        print '\tinput format:'
        print '\t\tlabel <tab> path_to_tfdragonn-train'
        sys.exit(1)

    inputs = sys.argv[1]
    outfilename = sys.argv[2]

    DataDict = {}

    listoflines = open(inputs)
    for Lline in listoflines:
        Ffields = Lline.strip().split('\t')
        label = Ffields[0]
        filename = Ffields[1]
        linelist = open(filename)
        epoch = ''
        for line in linelist:
            if line.startswith("The best model's architecture and weights (from "):
                epoch = line.split('(from epoch ')[1].split(')')[0]
        if epoch == '':
            DataDict[label] = ('nan','nan','nan', 'nan','nan','nan', 'nan','nan','nan')
            continue
        linelist = open(filename)
        InEpoch = False
        for line in linelist:
            if line.startswith('Epoch'):
                currentEpoch = line.strip().split('Epoch ')[1].split(':')[0]
                if currentEpoch == epoch:
                    InEpoch = True
                else:
                    InEpoch = False
                continue
            elif line.startswith('Balanced Accuracy'):
                if InEpoch:
                    newline = line.replace('\t',' ')
                    newline = newline.replace('  ',' ')
                    newline = newline.replace('  ',' ')
                    newline = newline.replace('  ',' ')
                    newline = newline.replace('  ',' ')
                    newline = newline.replace('  ',' ')
                    newline = newline.replace('  ',' ')
                    newline = newline.replace('  ',' ').strip()
                    BACC = newline.split(' ')[2]
                    auROC = newline.split(' ')[4]
                    auPRC = newline.split(' ')[6]
                else:
                    pass
                continue
            elif line.startswith('Recall at 5'):
                if InEpoch:
                    newline = line.replace('\t',' ')
                    newline = newline.replace('  ',' ')
                    newline = newline.replace('  ',' ')
                    newline = newline.replace('  ',' ')
                    newline = newline.replace('  ',' ')
                    newline = newline.replace('  ',' ')
                    newline = newline.replace('  ',' ')
                    newline = newline.replace('  ',' ').strip()
                    RecallFDR5 = newline.split(' ')[10]
                    RecallFDR10 = newline.split(' ')[12]
                    RecallFDR25 = newline.split(' ')[14]
                    RecallFDR50 = newline.split(' ')[16]
                    NumPos = newline.split(' ')[19]
                    NumNeg = newline.split(' ')[22]
                else:
                    pass
                continue
            else:
                continue
        DataDict[label] = (BACC,auROC,auPRC,RecallFDR5,RecallFDR10,RecallFDR25,RecallFDR50,NumPos,NumNeg)

    labels = DataDict.keys()
    labels.sort()

    outfile = open(outfilename,'w')
    outline = '#label\tBACC\tauROC\tauPRC\tRecall@FDR=5%\tRecall@FDR=10%\tRecall@FDR=25%\tRecall@FDR=50%\tNumPos\tNumNeg\tImbalance'
    outfile.write(outline + '\n')

    for label in labels:
        (BACC,auROC,auPRC,RecallFDR5,RecallFDR10,RecallFDR25,RecallFDR50,NumPos,NumNeg) = DataDict[label]
        outline = label + '\t' + BACC + '\t' + auROC + '\t' + auPRC
        outline = outline + '\t' + RecallFDR5 + '\t' + RecallFDR10 + '\t' + RecallFDR25 + '\t' + RecallFDR50
        outline = outline + '\t' + NumPos + '\t' + NumNeg + '\t' + str(float(NumNeg)/float(NumPos))
        outfile.write(outline + '\n')

    outfile.close()

# Only run when executed as a script, so importing this module has no side
# effects (previously an import would parse sys.argv and possibly exit).
if __name__ == '__main__':
    run()
