##################################
#                                #
# Last modified 03/12/2014       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import math
import random
import string

def FractionWrongMI(GeneDict,ID1,ID2,minFPKM,maxFPKM,minNI,maxNI,thetaHigh=0.05,thetaLow=0.001):
    '''
    Compare isoform-level FPKMs stored under field ID2 against those under ID1.

    ID1 is treated as the reference (presumably the ground truth) and ID2 as
    the estimate being evaluated. A gene is considered only when:
      - its summed FPKM under ID1 lies in [minFPKM, maxFPKM),
      - its isoform count (length of the ID1 list) lies in [minNI, maxNI), and
      - its summed FPKM under ID1 is non-zero.

    For every isoform of a considered gene, Theta is the isoform's share of
    the gene's summed FPKM in the respective field. Theta2 is 0 when the
    gene's summed ID2 FPKM is 0.
      - false positive (FP) isoform: Theta2 > thetaHigh and Theta1 == 0
      - false negative (FN) isoform: Theta1 > thetaHigh and Theta2 < thetaLow

    GeneDict  -- {gene: {'FPKM': {fieldID: [isoform FPKM, ...]}, ...}};
                 the ID1 and ID2 lists are assumed to be parallel
                 (same isoforms in the same order)
    ID1, ID2  -- field IDs, used as keys of GeneDict[gene]['FPKM']
    thetaHigh, thetaLow -- isoform-fraction thresholds (default 0.05, 0.001)

    Side effect: prints the number of genes that passed the filters.

    Returns a 5-tuple of floats:
      (fraction of genes whose highest-FPKM isoform differs between ID1 and ID2,
       fraction of genes with at least one FP isoform,
       fraction of genes with at least one FN isoform,
       FP isoforms / isoforms of considered genes,
       FN isoforms / isoforms of considered genes)
    All five values are NaN when no gene passes the filters.
    '''

    # Floats so that the final ratios are true divisions under Python 2 as well.
    Genes = 0.0
    Transcripts = 0.0
    WrongMIGenes = 0.0

    FPT = 0
    FNT = 0
    FPG = set()
    FNG = set()

    for gene in GeneDict:
        FPKMs1 = GeneDict[gene]['FPKM'][ID1]
        FPKMs2 = GeneDict[gene]['FPKM'][ID2]
        FPKMG1 = sum(FPKMs1)
        FPKMG2 = sum(FPKMs2)
        NI = len(FPKMs1)
        if FPKMG1 < minFPKM or FPKMG1 >= maxFPKM:
            continue
        if NI < minNI or NI >= maxNI:
            continue
        # Guards the Theta1 division below.
        if FPKMG1 == 0:
            continue
        Genes += 1
        # list.index(max(...)) picks the first isoform when several share the maximum.
        if FPKMs1.index(max(FPKMs1)) != FPKMs2.index(max(FPKMs2)):
            WrongMIGenes += 1
        for FPKM1, FPKM2 in zip(FPKMs1, FPKMs2):
            Transcripts += 1
            Theta1 = FPKM1/FPKMG1
            if FPKMG2 == 0:
                Theta2 = 0
            else:
                Theta2 = FPKM2/FPKMG2
            if Theta1 > thetaHigh and Theta2 < thetaLow:
                FNG.add(gene)
                FNT += 1
            if Theta2 > thetaHigh and Theta1 == 0:
                FPG.add(gene)
                FPT += 1

    print(Genes)

    # No gene survived the filters: the ratios are undefined (previously a ZeroDivisionError).
    if Genes == 0:
        nan = float('nan')
        return (nan,nan,nan,nan,nan)

    # Genes > 0 implies at least one isoform (FPKMG1 != 0 needs a non-empty list),
    # so Transcripts is non-zero here.
    return (WrongMIGenes/Genes,len(FPG)/Genes,len(FNG)/Genes,FPT/Transcripts,FNT/Transcripts)

def _read_config(configfilename):
    '''
    Parse the tab-separated config file (label, fieldID1, fieldID2 per line).

    Blank lines and lines starting with '#' are skipped.
    Returns ({label: (ID1, ID2)}, {fieldID: 1 for every ID1/ID2 mentioned}).
    '''
    RMSEDict = {}
    WantedFields = {}
    with open(configfilename) as linelist:
        for line in linelist:
            if not line.strip() or line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            label = fields[0]
            ID1 = int(fields[1])
            ID2 = int(fields[2])
            RMSEDict[label] = (ID1,ID2)
            WantedFields[ID1] = 1
            WantedFields[ID2] = 1
    return RMSEDict, WantedFields


def _read_table(inputfilename, geneIDfields, transcriptIDfields, WantedFields):
    '''
    Group the rows of the tab-separated table by gene.

    Lines starting with '#' and blank lines are skipped. The gene and
    transcript keys are tuples of the columns listed in geneIDfields and
    transcriptIDfields.
    Returns {gene: {'FPKM': {fieldID: [float, ...]}, 'isoforms': [transcript, ...]}},
    with one list entry per row, in file order.
    '''
    GeneDict = {}
    with open(inputfilename) as linelist:
        for line in linelist:
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            gene = tuple([fields[ID] for ID in geneIDfields])
            transcript = tuple([fields[ID] for ID in transcriptIDfields])
            if gene not in GeneDict:
                GeneDict[gene] = {'FPKM': {}, 'isoforms': []}
            GeneDict[gene]['isoforms'].append(transcript)
            for ID in WantedFields:
                # Wanted columns that this row does not have (or negative IDs) are skipped.
                if 0 <= ID < len(fields):
                    GeneDict[gene]['FPKM'].setdefault(ID, []).append(float(fields[ID]))
    return GeneDict


def run():
    '''
    Command-line entry point.

    Reads the isoform table and the config file named on the command line.
    For every config label, in sorted order, it writes one tab-separated row
    with the five statistics returned by FractionWrongMI to outfilename.
    Prints the usage message and exits with status 1 when fewer than nine
    arguments are supplied.
    '''

    # Nine user arguments are required: sys.argv[1] .. sys.argv[9].
    if len(sys.argv) < 10:
        print('usage: python %s table geneIDfield(s) transcriptIDfield(s) minGeneFPKM maxGeneFPKM minNumberIsoforms maxNumberIsoforms config outfilename' % sys.argv[0])
        print('\tconfig file format:')
        print('\tlabel\tfieldID1\tfieldID2')
        print('\tthe minGeneFPKM and maxGeneFPKM refer to the gene-level FPKMs which will be calculated from the sum of the transcript-level FPKMs in fieldID1')
        print('\tgenes are kept if minGeneFPKM <= gene FPKM < maxGeneFPKM and minNumberIsoforms <= number of isoforms < maxNumberIsoforms')
        sys.exit(1)

    inputfilename = sys.argv[1]
    geneIDfields = [int(ID) for ID in sys.argv[2].split(',')]
    transcriptIDfields = [int(ID) for ID in sys.argv[3].split(',')]
    minFPKM = float(sys.argv[4])
    maxFPKM = float(sys.argv[5])
    minNI = int(sys.argv[6])
    maxNI = int(sys.argv[7])
    config = sys.argv[8]
    outputfilename = sys.argv[9]

    RMSEDict, WantedFields = _read_config(config)
    GeneDict = _read_table(inputfilename, geneIDfields, transcriptIDfields, WantedFields)

    with open(outputfilename,'w') as outfile:
        # No trailing tab, so the header has the same number of columns as the data rows.
        outfile.write('#Name\tFraction_Wrong_MI\tGenesWithFPIsoforms,0.05Theta\tGenesWithFNIsoforms,0.05Theta\tFPIsoforms,0.05Theta\tfnIsoforms,0.05Theta\n')
        for label in sorted(RMSEDict):
            (ID1,ID2) = RMSEDict[label]
            results = FractionWrongMI(GeneDict,ID1,ID2,minFPKM,maxFPKM,minNI,maxNI)
            outfile.write('\t'.join([label] + [str(value) for value in results]) + '\n')

# Run only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    run()

