##################################
#                                #
# Last modified 03/15/2014       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
from sets import Set

def _parse_field_ids(spec):
    """Turn a comma-separated string of column indices into a tuple of ints."""
    return tuple(int(field_id) for field_id in spec.split(','))


def _read_gene_fpkms(path, geneIDs, transcriptIDs, FPKMfieldID):
    """Load the tab-delimited table at *path* as {gene key: {transcript key: FPKM}}.

    Gene and transcript keys are tuples of the column values selected by
    geneIDs and transcriptIDs. Lines starting with '#' and blank lines are
    skipped. If a (gene, transcript) pair occurs more than once, the last
    FPKM value read is kept.
    """
    gene_dict = {}
    with open(path) as handle:
        for line in handle:
            if line.startswith('#'):
                continue
            # NOTE: strip() also removes leading/trailing tabs, so empty first
            # or last columns are dropped before splitting (kept as-is).
            stripped = line.strip()
            if not stripped:
                # A blank line has no columns to index; skip it.
                continue
            fields = stripped.split('\t')
            gene = tuple(fields[field_id] for field_id in geneIDs)
            transcript = tuple(fields[field_id] for field_id in transcriptIDs)
            fpkm = float(fields[FPKMfieldID])
            gene_dict.setdefault(gene, {})[transcript] = fpkm
    return gene_dict


def _rank_fractions(gene_dict):
    """Return {rank: [FPKM / gene maximum FPKM, ...]} with 1-based ranks.

    Within each gene the transcripts are ranked by descending FPKM. Genes
    whose maximum FPKM equals 0 are left out entirely.
    """
    fmi_dict = {}
    for transcripts in gene_dict.values():
        values = sorted(transcripts.values(), reverse=True)
        # Compute the per-gene maximum once instead of once per transcript.
        top = max(values)
        if top == 0:
            continue
        for rank, value in enumerate(values, 1):
            fmi_dict.setdefault(rank, []).append(value / top)
    return fmi_dict


def _write_table(path, fmi_dict, ranks):
    """Write a '#'-prefixed header of ranks, then one row per retained gene.

    Every cell is followed by a tab; a cell is left empty when the gene in
    that row has fewer transcripts than the column's rank.
    """
    # Every retained gene contributes a rank-1 value, so that column's length
    # is the row count. It is absent only when no gene was retained.
    row_count = len(fmi_dict.get(1, []))
    with open(path, 'w') as outfile:
        outfile.write('#' + ''.join(str(rank) + '\t' for rank in ranks) + '\n')
        for row in range(row_count):
            cells = []
            for rank in ranks:
                column = fmi_dict[rank]
                if row < len(column):
                    cells.append(str(column[row]))
                else:
                    cells.append('')
            outfile.write(''.join(cell + '\t' for cell in cells) + '\n')


def run():
    """Tabulate FPKM fractions of the per-gene maximum, by transcript rank.

    Command-line arguments (read from sys.argv):
        1  input file, tab-delimited; lines starting with '#' are skipped
        2  comma-separated column indices that identify a gene
        3  comma-separated column indices that identify a transcript
        4  column index of the FPKM value
        5  output file name

    Prints the usage text and exits with status 1 when fewer than five
    arguments are given. The sorted list of ranks is printed to stdout and
    the table is written to the output file.
    """
    # Five arguments are required (sys.argv[5] is read below), so sys.argv
    # must hold at least six entries including the script name.
    if len(sys.argv) < 6:
        # Single-argument parenthesised print behaves the same as the
        # Python 2 print statement.
        print('usage: python %s input geneFieldID(s) transcriptFieldID(s) FPKMFieldID <outfilename>' % sys.argv[0])
        print('       field IDs comma-separated')
        sys.exit(1)

    input_path = sys.argv[1]
    geneIDs = _parse_field_ids(sys.argv[2])
    transcriptIDs = _parse_field_ids(sys.argv[3])
    FPKMfieldID = int(sys.argv[4])
    outfilename = sys.argv[5]

    GeneDict = _read_gene_fpkms(input_path, geneIDs, transcriptIDs, FPKMfieldID)
    FMIDict = _rank_fractions(GeneDict)

    # Use one sorted rank list for both the header and the data columns so
    # the column labels always match the values beneath them.
    ranks = sorted(FMIDict)
    print(ranks)

    _write_table(outfilename, FMIDict, ranks)
	
# Script entry point: called unconditionally at module load, so importing
# this file also runs it (there is no __main__ guard).
run()
