##################################
#                                #
# Last modified 11/29/2013       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string

def _gene_key(transcript_id, separator, gene_field_ids):
    """Return the key under which one transcript's values are summed.

    Default mode (``gene_field_ids`` is None): everything before the last ':'
    of ``transcript_id`` is the gene, e.g. 'hg19:A1BGAS:NR_015380' ->
    'hg19:A1BGAS'. An ID containing no ':' is its own gene.

    Subfield mode: ``transcript_id`` is split on ``separator`` and the key is
    the tuple of the requested subfields, in the requested order. A requested
    index past the end of the ID repeats the most recently found subfield; if
    none has been found yet for this ID, the ID's last subfield is used.
    """
    if gene_field_ids is None:
        return transcript_id.rsplit(':', 1)[0]

    subfields = transcript_id.split(separator)
    # Fallback for out-of-range indices. Initialized per ID so a short ID
    # never inherits a subfield from a previous input line.
    last = subfields[-1]
    key = []
    for field_id in gene_field_ids:
        if field_id < len(subfields):
            last = subfields[field_id]
        key.append(last)
    return tuple(key)


def _sum_by_gene(infile, separator, gene_field_ids):
    """Sum RPKM and TPM per gene over the transcript lines of ``infile``.

    Column positions are looked up by name ('RPKM', 'TPM') in the header line
    that starts with '# Transcript' followed by a tab. Other lines starting
    with '#' and blank lines are skipped. The first column of each data line
    is the transcript ID, mapped to a gene key by ``_gene_key``.

    Returns a dict mapping gene key -> [rpkm_sum, tpm_sum].
    Raises ValueError if a data line appears before that header line.
    """
    totals = {}
    rpkm_col = tpm_col = None
    for line in infile:
        fields = line.strip().split('\t')
        if line.startswith('#'):
            if line.startswith('# Transcript\t'):
                rpkm_col = fields.index('RPKM')
                tpm_col = fields.index('TPM')
            continue
        if not line.strip():
            continue
        if rpkm_col is None:
            raise ValueError('no "# Transcript" header line with RPKM and TPM '
                             'columns before the first data line')
        rpkm = float(fields[rpkm_col])
        tpm = float(fields[tpm_col])
        key = _gene_key(fields[0], separator, gene_field_ids)
        if key in totals:
            totals[key][0] += rpkm
            totals[key][1] += tpm
        else:
            totals[key] = [rpkm, tpm]
    return totals


def run():
    """Command-line entry point: sum transcript-level RPKM/TPM per gene.

    Arguments (from sys.argv):
        input outputfilename [-subfields separator geneFieldID(s)]

    Writes ``outputfilename`` as tab-separated text: a '#' header line, then
    one line per gene key (sorted) with the summed RPKM followed by the summed
    TPM. Prints usage and exits with status 1 if fewer than two positional
    arguments are given.
    """
    if len(sys.argv) < 3:
        print('usage: python %s input outputfilename [-subfields separator geneFieldID(s)]' % sys.argv[0])
        print('\t# Transcript	Length	TPM	RPKM')
        print('\tit is assumed that gene names and transcript IDs are separated by :, i.e. hg19:A1BGAS:NR_015380 or just A1BGAS:NR_015380; everything after the last : will be considered unique transcript ID while everything before it will be considered common gene ID/Name')
        print('\tif the transcirpt names are in some other format, use the [-subfields separator gene] option (geneFieldID(s) separated by comma')
        print('       if fewer subfields are found than required, only the first that are found will be used')
        print('Note3: only the sum of RPKM and TPM will be outputted')
        sys.exit(1)

    input_name = sys.argv[1]
    outfilename = sys.argv[2]

    separator = None
    gene_field_ids = None  # None selects the default ':'-based gene naming
    if '-subfields' in sys.argv:
        opt = sys.argv.index('-subfields')
        separator = sys.argv[opt + 1]
        gene_field_ids = [int(field_id) for field_id in sys.argv[opt + 2].split(',')]
        # One 'geneID' column per requested subfield. The extra tab before
        # RPKM yields an empty column; data lines below emit the same empty
        # column, so header and data stay aligned.
        header = '#' + 'geneID\t' * len(gene_field_ids) + '\tRPKM\tTPM'
    else:
        header = '#geneName\tRPKM\tTPM'

    with open(input_name) as infile:
        totals = _sum_by_gene(infile, separator, gene_field_ids)

    with open(outfilename, 'w') as outfile:
        outfile.write(header + '\n')
        for key in sorted(totals):
            rpkm, tpm = totals[key]
            if gene_field_ids is None:
                gene_columns = key + '\t'
            else:
                gene_columns = ''.join(part + '\t' for part in key) + '\t'
            outfile.write(gene_columns + str(rpkm) + '\t' + str(tpm) + '\n')
   
# Run only when executed as a script, so importing this module has no side
# effects (no argv parsing, no file I/O, no sys.exit).
if __name__ == '__main__':
    run()