##################################
#                                #
# Last modified 11/02/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
from sets import Set

def _gtf_attribute(attributes, key):
    """Return the quoted value of *key* in a GTF attribute string, or None."""
    marker = key + ' "'
    if marker not in attributes:
        return None
    return attributes.split(marker)[1].split('"')[0]


def _read_gene_ids(genes):
    """Map every gene ID in the first column of *genes* to an empty dict."""
    GeneDict = {}
    with open(genes) as linelist:
        for line in linelist:
            if line.startswith('#'):
                continue
            GeneDict[line.strip().split('\t')[0]] = {}
    return GeneDict


def _load_annotation(gtf, GeneDict):
    """Register the transcripts of the requested genes found in *gtf*.

    Fills GeneDict[geneID][transcriptID] with an empty 'expression' dict and
    a 'status' ('known' for class code '=', 'novel' for 'j', 'other' for
    anything else) and returns the transcriptID -> geneID mapping.
    """
    TranscriptToGeneDict = {}
    with open(gtf) as linelist:
        for line in linelist:
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            # Blank/truncated lines and records without a transcript_id
            # (e.g. gene-level lines) carry no isoform information.
            if len(fields) < 9:
                continue
            transcriptID = _gtf_attribute(fields[8], 'transcript_id')
            geneID = _gtf_attribute(fields[8], 'gene_id')
            if transcriptID is None or geneID not in GeneDict:
                continue
            class_code = _gtf_attribute(fields[8], 'class_code')
            if class_code is None:
                # Without a class code only ENS* genes are kept, and their
                # transcripts are treated as exact matches ('=').
                if not geneID.startswith('ENS'):
                    continue
                class_code = '='
            if class_code == '=':
                status = 'known'
            elif class_code == 'j':
                status = 'novel'
            else:
                # Never reuse the status of the previously parsed transcript.
                status = 'other'
            GeneDict[geneID][transcriptID] = {'expression': {},
                                              'status': status}
            TranscriptToGeneDict[transcriptID] = geneID
    return TranscriptToGeneDict


def _load_expression(expr, IDfield, FPKMfields, GeneDict, TranscriptToGeneDict):
    """Store the FPKM (or 'FAIL') of every requested column per transcript."""
    with open(expr) as linelist:
        for line in linelist:
            if line.startswith('#') or line.startswith('tracking_id'):
                continue
            if not line.strip():
                continue
            fields = line.strip().split('\t')
            transcriptID = fields[IDfield]
            if transcriptID not in TranscriptToGeneDict:
                continue
            geneID = TranscriptToGeneDict[transcriptID]
            expression = GeneDict[geneID][transcriptID]['expression']
            for ID in FPKMfields:
                if 'FAIL' in fields[ID]:
                    expression[ID] = 'FAIL'
                else:
                    # Only the first comma-separated value of the cell is used.
                    expression[ID] = float(fields[ID].split(',')[0])


def _isoform_fractions(transcripts, ID, minFPKM, order):
    """Return [(rank, FPKM / major-isoform FPKM, status)] for one gene/column.

    An empty list is returned when any isoform is marked 'FAIL', when fewer
    than two isoforms have expression values, when the summed FPKM is below
    *minFPKM*, or when the major isoform has no expression (the ratio would
    be undefined). At most *order* top-ranked isoforms are reported.
    """
    ExpressionList = []
    TotalFPKM = 0
    for transcriptID in transcripts:
        expression = transcripts[transcriptID]['expression']
        if len(expression) == 0:
            continue
        if expression[ID] == 'FAIL':
            return []
        ExpressionList.append((expression[ID], transcriptID))
        TotalFPKM += expression[ID]
    if len(ExpressionList) < 2 or TotalFPKM < minFPKM:
        return []
    # Highest FPKM first; ties are broken by descending transcript ID.
    ExpressionList.sort(reverse=True)
    MIFPKM = ExpressionList[0][0]
    if MIFPKM <= 0:
        return []
    fractions = []
    for rank, (FPKM, transcriptID) in enumerate(ExpressionList[:max(order, 0)], 1):
        fractions.append((rank, FPKM / MIFPKM,
                          transcripts[transcriptID]['status']))
    return fractions


def run():
    """Write isoform-to-major-isoform FPKM ratios by rank for a set of genes.

    Command-line arguments (sys.argv[1:]): gene ID list, GTF, FPKM table,
    transcript ID column index, minimum summed FPKM per gene, maximum isoform
    rank to report, comma-separated FPKM column indices, output file name.

    The output is a tab-separated table with the columns Total_1..N,
    Known_1..N and Novel_1..N. Every column is an independent list of ratios
    pooled over all genes and FPKM columns, so the cells of one row do not
    necessarily belong to the same gene. Trailing empty cells of a row are
    not written.

    Exits with status 1 after printing usage if arguments are missing.
    """
    # Eight arguments are read below, so sys.argv needs nine entries.
    if len(sys.argv) < 9:
        print('usage: python %s <list-of-geneIDs> <gtf (with class codes)> <FPKM table> <transcriptID field ID> <FPKM threshold> <max isoform order> <FPKM fields> <outfilename>' % sys.argv[0])
        print('       fields and thresholds comma-separated')
        sys.exit(1)

    genes = sys.argv[1]
    gtf = sys.argv[2]
    expr = sys.argv[3]
    IDfield = int(sys.argv[4])
    minFPKM = float(sys.argv[5])
    order = int(sys.argv[6])
    FPKMfields = [int(v) for v in sys.argv[7].split(',')]
    outfilename = sys.argv[8]

    GeneDict = _read_gene_ids(genes)
    TranscriptToGeneDict = _load_annotation(gtf, GeneDict)

    print(len(GeneDict))

    _load_expression(expr, IDfield, FPKMfields, GeneDict, TranscriptToGeneDict)

    categories = [('total', 'Total_'), ('known', 'Known_'), ('novel', 'Novel_')]
    ranks = range(1, order + 1)

    # FMIDiCT[category][rank] -> list of ratios observed at that rank.
    FMIDiCT = {}
    header = []
    for category, label in categories:
        FMIDiCT[category] = {}
        for rank in ranks:
            FMIDiCT[category][rank] = []
            header.append(label + str(rank))

    # A dedicated counter keeps the progress report independent of the
    # rank indices used while collecting the ratios.
    for gene_count, geneID in enumerate(GeneDict, 1):
        if gene_count % 1000 == 0:
            print(str(gene_count))
        for ID in FPKMfields:
            fractions = _isoform_fractions(GeneDict[geneID], ID, minFPKM, order)
            for rank, FMI, status in fractions:
                FMIDiCT['total'][rank].append(FMI)
                if status in ('known', 'novel'):
                    FMIDiCT[status][rank].append(FMI)

    # Rank 1 of 'total' receives a value for every reported gene/column,
    # so it is the longest list and fixes the number of rows.
    row_count = len(FMIDiCT['total'].get(1, []))

    with open(outfilename, 'w') as outfile:
        outfile.write('\t'.join(header) + '\n')
        for row in range(row_count):
            cells = []
            for category, label in categories:
                for rank in ranks:
                    values = FMIDiCT[category][rank]
                    if len(values) > row:
                        cells.append(str(values[row]))
                    else:
                        cells.append('')
            outfile.write('\t'.join(cells).rstrip('\t') + '\n')
	
# Script entry point. The call is unconditional (no __main__ guard), so it
# also executes when this file is imported as a module.
run()
