##################################
#                                #
# Last modified 10/28/2012       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from sets import Set

def _parse_fields(spec):
    """Expand a column spec such as "0,3,5:7" into a sorted list of indices.

    Items are comma-separated.  An item written "start:end" expands to every
    integer from start to end, both ends included; anything else is parsed
    as a single integer.  Duplicates are kept.
    """
    indices = []
    for item in spec.split(','):
        if ':' in item:
            bounds = item.split(':')
            indices.extend(range(int(bounds[0]), int(bounds[1]) + 1))
        else:
            indices.append(int(item))
    indices.sort()
    return indices


def _read_read_stats(config_path):
    """Load the config file into {label: (table_RPM_reads, total_reads)}.

    Each non-comment line is tab-delimited: label, mapped reads used for the
    table entry, total mapped reads.  Lines starting with '#' and blank lines
    are skipped.
    """
    read_stats = {}
    with open(config_path) as config_file:
        for line in config_file:
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            read_stats[fields[0].strip()] = (int(fields[1]), int(fields[2]))
    return read_stats


def _rescale_rpm(rpm, table_RPM_reads, total_reads):
    """Turn an RPM computed against table_RPM_reads into an RPM of total_reads.

    The RPM is first converted back to a read count, then divided by the
    total in millions.  The divisors are floats on purpose: under Python 2,
    int / int floors (e.g. 1500000 / 1000000 == 1), which would skew the
    result and divide by zero for totals below one million.
    """
    reads = (rpm / 1000000.0) * table_RPM_reads
    return reads / (total_reads / 1000000.0)


def run():
    """Re-express selected RPM columns of a table relative to total mapped reads.

    Arguments are read from sys.argv:

        1. RPM_table            tab-delimited input table
        2. label_fields         columns copied to the output unchanged
        3. RPM_fields           columns holding RPM values to rescale
        4. SAMstats_config_file tab-delimited: label, mapped reads for the
                                table entry, total mapped reads
        5. outfilename          path of the tab-delimited output

    Column specs are 0-based, comma-separated indices or inclusive
    "start:end" ranges (see _parse_fields).  Each output row holds the label
    columns followed by the rescaled RPM columns, each group in ascending
    column order.  An RPM column is matched to a config entry through its
    entry in the most recent '#' header line; the rescaled value is
    RPM * table_RPM_reads / total_reads.

    NOTE(review): '#' header lines are copied verbatim, so the output header
    is not reduced to the selected columns.

    Prints usage and exits with status 1 when fewer than five arguments are
    given.  Raises ValueError if a data line appears before any '#' header
    line, and KeyError if a header label is missing from the config file.
    """
    if len(sys.argv) < 6:
        sys.stdout.write('usage: python %s RPM_table label_fields RPM_fields SAMstats_config_file outfilename\n' % sys.argv[0])
        sys.stdout.write('\t fields should be comma-separated or in "start:end" format, start and end included\n')
        sys.stdout.write('\t format of config file: label (has to match header column entry) <tab> mapped_reads_for_table_entry <tab> total_mapped_reads\n')
        sys.exit(1)

    table = sys.argv[1]
    labels = _parse_fields(sys.argv[2])
    values = _parse_fields(sys.argv[3])
    config = sys.argv[4]
    outfilename = sys.argv[5]

    read_stats = _read_read_stats(config)

    # Maps each RPM column index to its sample label; set by a '#' header line.
    field_to_label = None
    with open(outfilename, 'w') as outfile:
        with open(table) as infile:
            for line_number, line in enumerate(infile, 1):
                if line_number % 10000 == 0:
                    sys.stdout.write('%d lines processed\n' % line_number)
                # Strip only the line ending: stripping all whitespace would
                # drop empty leading/trailing fields and shift column indices.
                fields = line.rstrip('\r\n').split('\t')
                if line.startswith('#'):
                    field_to_label = dict((ID, fields[ID].strip()) for ID in values)
                    outfile.write(line)
                    continue
                if not line.strip():
                    continue
                if field_to_label is None:
                    raise ValueError('%s: data line %d appears before the "#" header line'
                                     % (table, line_number))
                out_fields = [fields[ID] for ID in labels]
                for ID in values:
                    table_RPM_reads, total_reads = read_stats[field_to_label[ID]]
                    new_rpm = _rescale_rpm(float(fields[ID]), table_RPM_reads, total_reads)
                    out_fields.append(str(new_rpm))
                outfile.write('\t'.join(out_fields) + '\n')
        
# Run only when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    run()

