##################################
#                                #
# Last modified 03/11/2012       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set

def _parse_field_ids(spec):
    """Expand a column spec such as '3,5:7' into the list [3, 5, 6, 7].

    Items are comma-separated; an item written 'start:end' expands to every
    integer from start to end, inclusive of both ends.
    """
    field_ids = []
    for item in spec.split(','):
        if ':' in item:
            parts = item.split(':')
            field_ids.extend(range(int(parts[0]), int(parts[1]) + 1))
        else:
            field_ids.append(int(item))
    return field_ids


def _gtf_attribute(attributes, key):
    """Return the quoted value of `key` in a GTF attribute column, or None if absent."""
    marker = key + ' "'
    if marker not in attributes:
        return None
    return attributes.split(marker)[1].split('";')[0]


def _read_gtf(gtf_path):
    """Map every transcript_id found on 'exon' records to (gene_id, gene_name).

    gene_name falls back to gene_id when the record has no gene_name attribute.
    Comment ('#') and blank lines are skipped. Exits with a message if an exon
    record lacks gene_id or transcript_id.
    """
    transcript_to_gene = {}
    with open(gtf_path) as gtf:
        for line in gtf:
            if line.startswith('#') or not line.strip():
                continue
            # Only drop the line terminator: strip() would also remove
            # leading/trailing tabs and shift columns of empty edge fields.
            fields = line.rstrip('\r\n').split('\t')
            if fields[2] != 'exon':
                continue
            attributes = fields[8]
            gene_id = _gtf_attribute(attributes, 'gene_id')
            transcript_id = _gtf_attribute(attributes, 'transcript_id')
            if gene_id is None or transcript_id is None:
                sys.exit('exon record without gene_id/transcript_id in %s: %s'
                         % (gtf_path, line.rstrip()))
            gene_name = _gtf_attribute(attributes, 'gene_name')
            if gene_name is None:
                gene_name = gene_id
            transcript_to_gene[transcript_id] = (gene_id, gene_name)
    return transcript_to_gene


def _cell_fpkm(cell, skip_fail):
    """Return the FPKM value held in one table cell.

    A cell is either a bare number or 'value,status'. With skip_fail, a cell
    whose status is not 'OK', or a bare cell equal to 'FAIL', counts as 0;
    without skip_fail the status is ignored and the value is parsed as float.
    """
    value, sep, status = cell.partition(',')
    if skip_fail:
        if sep and status.split(',')[0] != 'OK':
            return 0
        if not sep and cell == 'FAIL':
            return 0
    return float(value)


def _read_tss_table(table_path, transcript_field, fpkm_fields, skip_fail,
                    transcript_to_gene):
    """Collect the per-row maximum FPKM for every gene referenced by the table.

    Each non-comment, non-blank row is treated as one TSS group. Returns
    {(gene_id, gene_name): {tss_key: max_fpkm}}, where tss_key is the raw text
    of the transcript column (comma-separated transcript IDs) and max_fpkm is
    the largest value over fpkm_fields, floored at 0. Exits with a message if
    a transcript is not present in transcript_to_gene.
    """
    gene_tss = {}
    with open(table_path) as table:
        for line in table:
            if line.startswith('#') or not line.strip():
                continue
            fields = line.rstrip('\r\n').split('\t')
            # The leading 0 floors the maximum, matching a running max seeded at 0.
            max_fpkm = max([0] + [_cell_fpkm(fields[i], skip_fail) for i in fpkm_fields])
            tss_key = fields[transcript_field]
            for transcript_id in tss_key.split(','):
                try:
                    gene = transcript_to_gene[transcript_id]
                except KeyError:
                    sys.exit('transcript %s from %s not found in GTF'
                             % (transcript_id, table_path))
                gene_tss.setdefault(gene, {})[tss_key] = max_fpkm
    return gene_tss


def _write_counts(out_path, gene_tss, thresholds):
    """Write one row per gene: ID, name, number of TSS groups, then for each
    threshold the number of groups whose max FPKM is strictly greater than it.
    Rows are sorted by (gene_id, gene_name) so the output is deterministic."""
    header = ['#GeneID', 'GeneName', 'Number_TSSs'] + [str(t) for t in thresholds]
    with open(out_path, 'w') as out:
        out.write('\t'.join(header) + '\n')
        for gene in sorted(gene_tss):
            tss_fpkms = list(gene_tss[gene].values())
            row = [gene[0], gene[1], str(len(tss_fpkms))]
            row += [str(sum(1 for fpkm in tss_fpkms if fpkm > t)) for t in thresholds]
            out.write('\t'.join(row) + '\n')


def run():
    """Count, per gene, the TSS groups whose maximum FPKM exceeds each threshold.

    Positional command-line arguments (column indices are 0-based):
      1 TSS_FPKM_table       tab-separated table; each row is one TSS group
      2 transcriptsFieldID   column holding comma-separated transcript IDs
      3 FPKM_fields          FPKM columns, e.g. '4,6:9' (ranges inclusive)
      4 FPKM_thresholds      comma-separated thresholds (sorted before use)
      5 GTF                  annotation mapping transcript_id -> gene_id/gene_name
      6 outputfilename
    '-skipFAIL' anywhere on the command line makes FPKM cells whose status is
    not OK (or bare 'FAIL' cells) count as 0.
    """
    # Six positional arguments are required, so argv must have at least 7 entries.
    if len(sys.argv) < 7:
        print('usage: python %s TSS_FPKM_table transcriptsFieldID FPKM_fields FPKM_thresholds GTF outputfilename [-skipFAIL]' % sys.argv[0])
        print('FPKM fields: comma-separated, or in start:end (inclusive) format, or any combination of')
        print('FPKM thresholds: comma-separated')
        sys.exit(1)

    skip_fail = '-skipFAIL' in sys.argv
    table_path = sys.argv[1]
    transcript_field = int(sys.argv[2])
    fpkm_fields = _parse_field_ids(sys.argv[3])
    thresholds = sorted(float(t) for t in sys.argv[4].split(','))
    gtf_path = sys.argv[5]
    out_path = sys.argv[6]

    transcript_to_gene = _read_gtf(gtf_path)
    gene_tss = _read_tss_table(table_path, transcript_field, fpkm_fields,
                               skip_fail, transcript_to_gene)
    _write_counts(out_path, gene_tss, thresholds)
   
# Run only when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    run()
