##################################
#                                #
# Last modified 2024/03/26       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string

def _print_usage_and_exit():
    """Print the command-line help text and terminate with exit status 1."""
    # NOTE(review): the help text mentions a -TPM option, but no such flag is
    # parsed anywhere in run(); the text is kept verbatim.
    print('usage: python %s input outputfilename [-subfields separator geneFieldID(s)] [-renormalize]' % sys.argv[0])
    print('Note1: the following input format is assumed: transcript_id   gene_id length  effective_length        expected_count  TPM     FPKM    IsoPct')
    print('Note2: it is assumed that gene names and transcript IDs are separated by :, i.e. hg19:A1BGAS:NR_015380 or just A1BGAS:NR_015380; everything after the last : will be considered unique transcript ID while everything before it will be considered common gene ID/Name')
    print('       if the transcirpt names are in some other format, use the [-subfields separator gene] option (geneFieldID(s) separated by comma')
    print('       if fewer subfields are found than required, only the first that are found will be used')
    print('       use the -TPM option for version 1.5.0 and above')
    sys.exit(1)


def _gene_key(ID, doSpecialFields, separator, GeneSubFieldIDs):
    """Return the key under which the row with identifier ID is aggregated.

    With doSpecialFields, ID is split on separator and the subfields listed in
    GeneSubFieldIDs are returned as a tuple.  Otherwise ID is split on ':' and
    everything before the last ':' is returned as a string (the whole ID if
    it contains no ':').
    """
    if doSpecialFields:
        IDfields = ID.split(separator)
        # A missing subfield repeats the most recent one that was found.  Start
        # from the full ID so that a missing *first* subfield neither raises
        # nor picks up a value left over from a previous row.
        last = ID
        geneName = []
        for GSFID in GeneSubFieldIDs:
            if len(IDfields) > GSFID:
                last = IDfields[GSFID]
            geneName.append(last)
        return tuple(geneName)

    IDfields = ID.split(':')
    if len(IDfields) == 1:
        return IDfields[0]
    return ':'.join(IDfields[:-1])


def run():
    """Sum per-transcript expression values into per-gene values.

    Command line (read from sys.argv):
        input outputfilename [-subfields separator geneFieldID(s)] [-renormalize]

    The input is a tab-separated table whose header line starts with
    'gene_id' or 'transcript_id' and names the columns 'expected_count',
    'FPKM' and 'TPM'.  The second column of every data row is the identifier
    from which the gene key is derived (see _gene_key).  For each gene key the
    expected counts, FPKMs and TPMs of all its rows are added up and written,
    sorted by key, to outputfilename.

    With -renormalize the summed TPMs are rescaled so that they add up to
    1,000,000 over the whole file.

    Raises:
        ValueError: if a data row is met before any header line, or if a
            numeric column cannot be parsed as a float.
    """
    # Two positional arguments are required (sys.argv[1] and sys.argv[2]).
    if len(sys.argv) < 3:
        _print_usage_and_exit()

    inputfilename = sys.argv[1]
    outfilename = sys.argv[2]

    doRenorm = '-renormalize' in sys.argv
    if doRenorm:
        print('will renormalize TPMs')

    doSpecialFields = '-subfields' in sys.argv
    separator = None
    GeneSubFieldIDs = []
    if doSpecialFields:
        pos = sys.argv.index('-subfields')
        # -subfields needs two operands: the separator and the field list.
        if len(sys.argv) < pos + 3:
            _print_usage_and_exit()
        separator = sys.argv[pos + 1]
        GeneSubFieldIDs = [int(ID) for ID in sys.argv[pos + 2].split(',')]
        header = '#' + 'geneID\t' * len(GeneSubFieldIDs) + 'expected_count\tFPKM\tTPM'
    else:
        header = '#geneName\texpected_count\tFPKM\tTPM'

    GeneDict = {}
    ReNorm = 0
    TotCountFieldID = None
    FPKMID = None
    TPMID = None

    # The whole input is parsed before the output file is opened, so a bad
    # input does not leave a truncated output file behind.
    with open(inputfilename) as linelist:
        for line in linelist:
            if not line.strip():
                continue
            fields = line.strip().split('\t')
            # Accept both header layouts: 'gene_id ...' and the documented
            # 'transcript_id gene_id ...'.
            if line.startswith(('gene_id\t', 'transcript_id\t')):
                TotCountFieldID = fields.index('expected_count')
                FPKMID = fields.index('FPKM')
                TPMID = fields.index('TPM')
                continue
            if TPMID is None:
                raise ValueError('no header line found before the first data line in %s' % inputfilename)

            TotCount = float(fields[TotCountFieldID])
            FPKM = float(fields[FPKMID])
            TPM = float(fields[TPMID])
            geneName = _gene_key(fields[1], doSpecialFields, separator, GeneSubFieldIDs)

            # Total TPM over all rows; the denominator for -renormalize.
            ReNorm += TPM

            entry = GeneDict.get(geneName)
            if entry is None:
                GeneDict[geneName] = {'FPKM': FPKM, 'TPM': TPM, 'TotCount': TotCount}
            else:
                entry['FPKM'] += FPKM
                entry['TPM'] += TPM
                entry['TotCount'] += TotCount

    print('total TPM: %s' % ReNorm)

    with open(outfilename, 'w') as outfile:
        outfile.write(header + '\n')
        for geneName in sorted(GeneDict):
            entry = GeneDict[geneName]
            if doSpecialFields:
                label = '\t'.join(geneName)
            else:
                label = geneName
            TPM = entry['TPM']
            # Skip rescaling when the total is zero (all TPMs are zero then).
            if doRenorm and ReNorm:
                TPM = (TPM / ReNorm) * 1000000
            columns = [label, str(entry['TotCount']), str(entry['FPKM']), str(TPM)]
            outfile.write('\t'.join(columns) + '\n')
   
# Run the aggregation only when the file is executed as a script, so that
# importing the module does not parse sys.argv or call sys.exit().
if __name__ == '__main__':
    run()