##################################
#                                #
# Last modified 05/26/2015       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import os
import math
from sets import Set

def _usage_and_exit():
    """Print the command-line usage message to stdout and exit with status 1."""
    print('usage: python %s transcripts.chrom.sizes transcript_FPKM_table transcriptID FPKMfieldID1 FPKMFieldID2 totalNumberofReads outfilename [-splitTranscriptNames separator fieldID] [-foldChange pseudocount]' % sys.argv[0])
    sys.exit(1)


def _option_args(argv, flag, count):
    """Return the `count` arguments that follow `flag` in `argv`.

    Returns None when `flag` is absent. Prints usage and exits when the flag
    is present but fewer than `count` values follow it, instead of raising
    IndexError.
    """
    if flag not in argv:
        return None
    start = argv.index(flag) + 1
    values = argv[start:start + count]
    if len(values) < count:
        _usage_and_exit()
    return values


def _data_fields(path):
    """Yield the tab-separated fields of every data line in `path`.

    Lines starting with '#' (headers/comments) and blank lines are skipped.
    The file is opened with `with`, so it is closed when iteration completes.
    """
    with open(path) as handle:
        for line in handle:
            if line.startswith('#') or not line.strip():
                continue
            yield line.strip().split('\t')


def _normalize(count, total, num_reads):
    """Scale `count` by num_reads / total, i.e. count / (total / num_reads).

    Returns 0.0 when `total` is zero. For non-negative FPKMs this means every
    count in that sample is zero. Returning 0.0 avoids a ZeroDivisionError.
    """
    if total == 0:
        return 0.0
    return count / (total / num_reads)


def _fold_change(r1, r2, pseudocount):
    """Return (r2 + pseudocount) / (r1 + pseudocount) as a string.

    A zero `r2` is reported as '0' regardless of the pseudocount. When the
    denominator is zero (e.g. r1 == 0 with a zero pseudocount) and r2 is
    non-zero, 'inf' is returned instead of raising ZeroDivisionError.
    """
    # NOTE(review): the r2 == 0 -> 0 rule ignores the pseudocount; confirm
    # this is intended rather than pseudocount / (r1 + pseudocount).
    if r2 == 0:
        return str(0)
    denominator = r1 + pseudocount
    if denominator == 0:
        return 'inf'
    return str((r2 + pseudocount) / denominator)


def run():
    """Convert two FPKM columns into normalized estimated read counts.

    Positional arguments (see the usage message):
      argv[1]  transcript sizes file: tab-separated full ID and length.
      argv[2]  FPKM table, tab-separated.
      argv[3]  0-based column of the transcript ID in the FPKM table.
      argv[4], argv[5]  0-based columns of the two FPKM values.
      argv[6]  total number of reads (integer).
      argv[7]  output file name.
    In both input files, '#' lines and blank lines are skipped.

    Options:
      -splitTranscriptNames separator fieldID
          FPKM-table IDs match field `fieldID` of the size-file ID split on
          `separator`. Output rows still use the full size-file ID.
      -foldChange pseudocount
          Append a fold_change column: (R2 + pc) / (R1 + pc).

    Computation:
      - Each FPKM-table row gets R = FPKM * (length / 1000) * (NumReads / 1e6).
      - Each sample's R is then scaled by NumReads / (sum of that sample's R
        over the FPKM table).
      - One output row is written per size-file line, in file order.
      - Raises KeyError if an ID is missing from either lookup.
    """
    argv = sys.argv
    # Index 7 (outfilename) is read unconditionally, so 8 entries are required.
    if len(argv) < 8:
        _usage_and_exit()

    transcript_sizes = argv[1]
    FPKMtable = argv[2]
    transcriptFieldID = int(argv[3])
    FPKMFieldID1 = int(argv[4])
    FPKMFieldID2 = int(argv[5])
    NumReads = int(argv[6])
    outfilename = argv[7]

    split_opts = _option_args(argv, '-splitTranscriptNames', 2)
    doSplitNames = split_opts is not None
    if doSplitNames:
        separator = split_opts[0]
        SplitNameFieldID = int(split_opts[1])
        print('will split names %s %s' % (separator, SplitNameFieldID))

    fc_opts = _option_args(argv, '-foldChange', 1)
    doFC = fc_opts is not None
    if doFC:
        FCPC = float(fc_opts[0])

    # Full size-file IDs in file order (duplicates kept). These drive the
    # output rows, so the sizes file only needs to be read once.
    size_file_ids = []
    TranscriptSizeDict = {}   # lookup key (short name if splitting) -> length
    TranscriptNamesDict = {}  # short name -> full size-file ID (split mode)
    for fields in _data_fields(transcript_sizes):
        full_id = fields[0]
        size_file_ids.append(full_id)
        key = full_id
        if doSplitNames:
            key = full_id.split(separator)[SplitNameFieldID]
            TranscriptNamesDict[key] = full_id
        TranscriptSizeDict[key] = int(fields[1])

    reads_per_million = NumReads / 1000000.
    TR1 = 0
    TR2 = 0
    RCDict = {}  # full size-file ID -> (R1, R2)
    for fields in _data_fields(FPKMtable):
        ID = fields[transcriptFieldID]
        kb = TranscriptSizeDict[ID] / 1000.
        R1 = float(fields[FPKMFieldID1]) * kb * reads_per_million
        R2 = float(fields[FPKMFieldID2]) * kb * reads_per_million
        TR1 += R1
        TR2 += R2
        RCDict[TranscriptNamesDict[ID] if doSplitNames else ID] = (R1, R2)

    header = '#ID\tread_counts_1\tread_counts_2'
    if doFC:
        header += '\t' + 'fold_change'
    with open(outfilename, 'w') as outfile:
        outfile.write(header + '\n')
        for ID in size_file_ids:
            R1, R2 = RCDict[ID]
            # Both count columns are normalized the same way. The pseudocount
            # only enters the fold-change column.
            columns = [ID,
                       str(_normalize(R1, TR1, NumReads)),
                       str(_normalize(R2, TR2, NumReads))]
            if doFC:
                columns.append(_fold_change(R1, R2, FCPC))
            outfile.write('\t'.join(columns) + '\n')

# Only run when executed as a script, so importing this module does not parse
# sys.argv and exit.
if __name__ == '__main__':
    run()