##################################
#                                #
# Last modified 08/01/2013       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import scipy.stats
import numpy
import math
import random
from sets import Set
import time

def run():

    if len(sys.argv) < 9:
        print 'usage: python %s FPKM_table geneNameFieldID geneIDFieldID geneFPKMFieldID GTF mu alpha minFMI outfile [-geneNameOnly] [-genetype biotype]' % sys.argv[0]
        print '\t use the mu parameter to specify the initial FMI mean'
        print '\t use the alpha parameter to specify the isoform complexity; refer to text for full details, but in short, low alpha => high isoform complexity; high alpha => low isoform complexity'
        print '\t use the minFMI parameter to indicate the FMI below which all isoforms will be assigned an FMI of 0'
        sys.exit(1)

    FPKMtable = sys.argv[1]
    geneNameFieldID = int(sys.argv[2])
    geneIDFieldID = int(sys.argv[3])
    FPKMFieldID = int(sys.argv[4])
    GTF = sys.argv[5]
    mu = float(sys.argv[6])
    alpha = float(sys.argv[7])
    minFMI = float(sys.argv[8])
    outfilename = sys.argv[9]

    doGeneNameOnly = False
    if '-geneNameOnly' in sys.argv:
        doGeneNameOnly = True

    doGeneType = False
    if '-genetype' in sys.argv:
        doGeneType = True
        biotype = sys.argv[sys.argv.index('-genetype') + 1]

    GeneDict = {}

    linelist = open(FPKMtable)
    for line in linelist:
        if line.startswith('#') or line.startswith('tracking_id'):
            continue
        fields = line.strip().split('\t')
        geneName = fields[geneNameFieldID]
        geneID = fields[geneIDFieldID]
        FPKM = float(fields[FPKMFieldID])
        if doGeneNameOnly:
            gene = (geneName,geneName)
        else:
            gene = (geneName,geneID)
        GeneDict[gene]={}
        GeneDict[gene]['FPKM'] = FPKM
        GeneDict[gene]['transcripts'] = {}

    lineslist = open(GTF)
    TranscriptDict={}
    for line in lineslist:
        if line.startswith('#'):
            continue
        fields=line.strip().split('\t')
        if fields[2]!='exon':
            continue
        geneID=fields[8].split('gene_id "')[1].split('";')[0]
        if 'gene_name "' in fields[8]:
            geneName=fields[8].split('gene_name "')[1].split('";')[0]
        else:
            geneName=fields[8].split('gene_id "')[1].split('";')[0]
        if doGeneNameOnly:
            geneID = geneName
        if doGeneType:
            gene_type=fields[8].split('gene_type "')[1].split('";')[0]
            if gene_type == biotype:
                pass
            else:
                gene = (geneName,geneID)
                if GeneDict.has_key(gene):
                    del GeneDict[gene]
                continue
        transcriptID=fields[8].split('transcript_id "')[1].split('";')[0]
        if 'transcript_name "' in fields[8]:
            transcriptName=fields[8].split('transcript_name "')[1].split('";')[0]
        else:
            transcriptName=fields[8].split('transcript_id "')[1].split('";')[0]
        try:
            gene = (geneName,geneID)
            transcript = (transcriptName,transcriptID)
            GeneDict[gene]['transcripts'][transcript]=(0,0)
        except:
            print gene, fields
            continue

    for gene in GeneDict.keys():
        N = len(GeneDict[gene]['transcripts'])
        FMI_list = [1]
        lastFMI = 1
        for i in range(1,N):
            if lastFMI <= minFMI:
                FMI_list.append(0)
                continue
            mean = math.pow(mu,i*alpha)
            picked = False
            while not picked:
                FMI = scipy.stats.norm.rvs(mean,mean)
                if FMI >= 0 and FMI <= lastFMI:
                    picked = True
                    lastFMI = FMI
                    if FMI < minFMI:
                        FMI = 0
                    FMI_list.append(FMI)
                    break
        FMI_list.sort()
        FMI_list.reverse()
        i=0
        geneFPKM = GeneDict[gene]['FPKM']
        for transcript in GeneDict[gene]['transcripts'].keys():
            FMI = FMI_list[i]
            transcriptFPKM = (FMI/sum(FMI_list))*geneFPKM
            GeneDict[gene]['transcripts'][transcript] = (transcriptFPKM,FMI)
            i+=1

    outfile = open(outfilename, 'w')

    outline = '#GeneID\tGeneName\tGeneFPKM\tTranscriptName\tTranscriptID\tTranscriptFPKM\tFMI'
    outfile.write(outline + '\n')

    genes = GeneDict.keys()
    genes.sort()

    for (geneName,geneID) in genes:
        gene = (geneName,geneID)
        for (transcriptName,transcriptID) in GeneDict[gene]['transcripts'].keys():
            transcript = (transcriptName,transcriptID)
            outline = geneID + '\t' + geneName + '\t' + str(GeneDict[(geneName,geneID)]['FPKM']) + '\t' + transcriptID + '\t' + transcriptName + '\t' + str(GeneDict[gene]['transcripts'][transcript][0]) + '\t' + str(GeneDict[gene]['transcripts'][transcript][1])
            outfile.write(outline + '\n')

    outfile.close()

# Run only when executed as a script, so importing this module does not
# parse the importer's sys.argv and exit.
if __name__ == '__main__':
    run()