##################################
#                                #
# Last modified 03/11/2014       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import math
import random
import string

def CalculateRMSEFMIdiff(list1,list2,minFPKM,maxFPKM):
    """Summarise the differences between two parallel lists of FMI values.

    list1 and list2 hold (FMI, geneFPKM, gene) tuples; entry i of list1 is
    compared against entry i of list2.  An entry is used only when the
    gene FPKM stored in list1 satisfies minFPKM <= FPKM < maxFPKM; the FPKM
    stored in list2 is not consulted.  The gene identifier is taken from
    list2's tuple.

    The number of distinct genes and the number of compared entries are
    printed to stdout (in that order).

    Returns a tuple (RMSEG, RMSET, FMIdiff):
        RMSEG   - sqrt(sum of squared FMI differences / number of genes)
        RMSET   - sqrt(sum of squared FMI differences / number of entries)
        FMIdiff - sum of absolute FMI differences / number of genes
    All three are NaN when no entry passes the FPKM filter, because the
    statistics are undefined in that case.

    Raises IndexError if list2 is too short to supply the partner of an
    entry of list1 that passed the filter.
    """

    # Float accumulators keep the divisions below true divisions even when
    # every FMI value happens to be an integer (relevant under Python 2).
    sumSquares = 0.0
    sumAbs = 0.0
    numT = 0
    genes = set()

    for i, (FMI1, FPKMG1, _) in enumerate(list1):
        if FPKMG1 < minFPKM or FPKMG1 >= maxFPKM:
            continue
        (FMI2, _, gene) = list2[i]
        genes.add(gene)
        numT += 1
        delta = FMI1 - FMI2
        sumSquares += delta * delta
        sumAbs += math.fabs(delta)

    numGenes = len(genes)

    # Single-argument parenthesised print behaves identically on Python 2 and 3.
    print(numGenes)
    print(numT)

    # Nothing passed the filter: avoid dividing by zero and report the
    # statistics as undefined instead of aborting the caller.
    if numT == 0:
        nan = float('nan')
        return (nan, nan, nan)

    RMSEG = math.sqrt(sumSquares / numGenes)
    RMSET = math.sqrt(sumSquares / numT)
    FMIdiff = sumAbs / numGenes

    return (RMSEG,RMSET,FMIdiff)

def run():
    """Command-line entry point: compare FMI values between pairs of table columns.

    Reads nine positional arguments from sys.argv:
        table              tab-delimited input file; lines starting with '#'
                           are skipped
        geneIDfield(s)     comma-separated column indexes forming the gene key
        transcriptIDfield(s) comma-separated column indexes forming the
                           transcript key
        minGeneFPKM, maxGeneFPKM
                           gene-level FPKM window passed to
                           CalculateRMSEFMIdiff
        minNumberIsoforms, maxNumberIsoforms
                           a gene is used only when
                           minNI <= number of its table rows < maxNI
        config             tab-delimited file with lines
                           label<TAB>fieldID1<TAB>fieldID2
        outfilename        output file to write

    For every column named in the config, the gene-level FPKM is the sum of
    that column over all rows of the gene, and each row's FMI is its value
    divided by that sum (0 when the sum is 0).  One output line is written
    per config label with the statistics returned by CalculateRMSEFMIdiff
    for the label's two columns.

    Exits with status 1 after printing usage when arguments are missing, or
    when a row has a positive value although its gene total is 0.
    """

    # Nine arguments follow the script name (sys.argv[9] is read below), so
    # sys.argv must hold at least ten entries.
    if len(sys.argv) < 10:
        print('usage: python %s table geneIDfield(s) transcriptIDfield(s) minGeneFPKM maxGeneFPKM minNumberIsoforms maxNumberIsoforms config outfilename' % sys.argv[0])
        print('\tconfig file format:')
        print('\tlabel\tfieldID1\tfieldID2')
        print('\tthe minGeneFPKM and maxGeneFPKM refer to the gene-level FPKMs which will be calculated from the sum of the transcript-levle FPKMs in fieldID1')
        print('\t < minNI and >= maxNI')
        sys.exit(1)

    tablefilename = sys.argv[1]
    geneIDfields = [int(ID) for ID in sys.argv[2].split(',')]
    trancriptIDfields = [int(ID) for ID in sys.argv[3].split(',')]
    minFPKM = float(sys.argv[4])
    maxFPKM = float(sys.argv[5])
    minNI = int(sys.argv[6])
    maxNI = int(sys.argv[7])
    config = sys.argv[8]
    outputfilename = sys.argv[9]

    RMSEDict = {}       # label -> (fieldID1, fieldID2)
    WantedFields = set()  # every column index named in the config
    DataDict = {}       # column index -> list of (FMI, geneFPKM, gene)

    with open(config) as configfile:
        for line in configfile:
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not line.strip():
                continue
            fields = line.strip().split('\t')
            label = fields[0]
            ID1 = int(fields[1])
            ID2 = int(fields[2])
            RMSEDict[label] = (ID1,ID2)
            DataDict[ID1] = []
            DataDict[ID2] = []
            WantedFields.add(ID1)
            WantedFields.add(ID2)

    # First pass: collect per-gene isoform lists and per-column FPKM totals.
    GeneDict = {}

    with open(tablefilename) as table:
        for line in table:
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            gene = tuple([fields[ID] for ID in geneIDfields])
            if gene not in GeneDict:
                GeneDict[gene] = {'FPKM': {}, 'isoforms': []}
            transcript = tuple([fields[ID] for ID in trancriptIDfields])
            GeneDict[gene]['isoforms'].append(transcript)
            geneFPKMs = GeneDict[gene]['FPKM']
            for ID in range(len(fields)):
                if ID in WantedFields:
                    geneFPKMs[ID] = geneFPKMs.get(ID, 0) + float(fields[ID])

    # Second pass: now that gene totals are known, turn each row's value
    # into a fraction of its gene total.
    with open(tablefilename) as table:
        for line in table:
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            gene = tuple([fields[ID] for ID in geneIDfields])
            NI = len(GeneDict[gene]['isoforms'])
            if NI < minNI or NI >= maxNI:
                continue
            for ID in range(len(fields)):
                if ID not in WantedFields:
                    continue
                FPKMG = GeneDict[gene]['FPKM'][ID]
                if FPKMG == 0:
                    FMI = 0
                    # A zero total with a positive member means the input
                    # is inconsistent; stop rather than report nonsense.
                    if float(fields[ID])> 0:
                        print('non-zero transcript FPKM for a gene with 0 FPKM, exiting')
                        print(gene)
                        sys.exit(1)
                else:
                    FMI = float(fields[ID])/FPKMG
                DataDict[ID].append((FMI,FPKMG,gene))

    labels = sorted(RMSEDict)

    with open(outputfilename,'w') as outfile:
        outfile.write('#Name\tRMSEG\tRMSET\tFMIdiff\n')

        for label in labels:
            (ID1,ID2) = RMSEDict[label]
            (RMSEG,RMSET,FMIdiff) = CalculateRMSEFMIdiff(DataDict[ID1],DataDict[ID2],minFPKM,maxFPKM)
            outline = '\t'.join([label, str(RMSEG), str(RMSET), str(FMIdiff)])
            outfile.write(outline + '\n')

# Execute only when run as a script, so importing this module (e.g. to reuse
# CalculateRMSEFMIdiff) does not trigger argument parsing or a sys.exit.
if __name__ == '__main__':
    run()

