##################################
#                                #
# Last modified 06/15/2010       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
from sets import Set

# psyco is an optional JIT accelerator for old Python 2 interpreters.
# The script must keep working when it is missing or cannot be enabled.
try:
    import psyco
    psyco.full()
except Exception:
    # Best-effort speed-up only: ignore any failure to import or enable
    # psyco, but (unlike a bare "except:") let KeyboardInterrupt and
    # SystemExit propagate.
    pass

def _read_experiment_list(listfilename):
    """Parse the list-of-files table into {ExpID: {'SAM': ..., 'junctions': ...}}.

    Each non-blank line must hold three tab-separated fields:
    ExpID, SAM file path, junctions file path.  If an ExpID occurs more
    than once, the last occurrence wins.
    """
    experiments = {}
    with open(listfilename) as listfile:
        for line in listfile:
            stripped = line.strip()
            if not stripped:
                # Tolerate blank (e.g. trailing) lines instead of crashing.
                continue
            fields = stripped.split('\t')
            experiments[fields[0]] = {'SAM': fields[1], 'junctions': fields[2]}
    return experiments


def _load_junctions(junctionsfilename, ExpID, JunctionsDict):
    """Add one experiment's junction counts to JunctionsDict (in place).

    Expected columns (tab-separated):
    chr, left, right, strand, total reads, staggered reads.
    JunctionsDict maps (chr, left, right, strand) -> {ExpID: (total, staggered)}.
    """
    with open(junctionsfilename) as junctionsfile:
        for line in junctionsfile:
            stripped = line.strip()
            if not stripped:
                continue
            fields = stripped.split('\t')
            junction = (fields[0], int(fields[1]), int(fields[2]), fields[3])
            counts = (int(fields[4]), int(fields[5]))
            JunctionsDict.setdefault(junction, {})[ExpID] = counts


def _count_unique_reads(samfilename, ExpID):
    """Return the number of distinct read IDs (first column) in a SAM file.

    Progress is reported every million lines.  SAM header lines (which
    start with '@') and blank lines are not alignment records and are
    therefore not counted as reads.
    """
    # Collect IDs straight into a set: building a list of every line first
    # would keep one entry per alignment in memory before de-duplication.
    readIDs = set()
    with open(samfilename) as samfile:
        for i, line in enumerate(samfile, 1):
            if i % 1000000 == 0:
                print('%s %d lines processed' % (ExpID, i))
            stripped = line.strip()
            if not stripped or stripped.startswith('@'):
                continue
            readIDs.add(stripped.split('\t', 1)[0])
    return len(readIDs)


def _write_table(outputfilename, DataDict, JunctionsDict):
    """Write the combined junction table.

    Layout: two '#' header lines (column names, then read counts in
    millions), followed by one line per junction with an RPM column for
    each experiment and then a staggered-count column for each experiment.
    Experiments and junctions are both written in sorted order.
    """
    Experiments = sorted(DataDict)

    with open(outputfilename, 'w') as outfile:
        header = ['#chr', 'left', 'right', 'strand']
        header.extend([ExpID + '_RPM' for ExpID in Experiments])
        header.extend([ExpID + '_staggered' for ExpID in Experiments])
        outfile.write('\t'.join(header) + '\n')

        # The read counts sit under the RPM columns; the staggered columns
        # stay empty.  Counts are shown in whole millions (floor division).
        reads_line = ['#Reads_mapped:', '', '', '']
        reads_line.extend([str(DataDict[ExpID]['ReadNumber'] // 1000000) + 'M'
                           for ExpID in Experiments])
        reads_line.extend([''] * len(Experiments))
        outfile.write('\t'.join(reads_line) + '\n')

        for junction in sorted(JunctionsDict):
            (chrom, left, right, strand) = junction
            per_experiment = JunctionsDict[junction]
            rpm_columns = []
            staggered_columns = []
            for ExpID in Experiments:
                # Experiments in which the junction was not seen count as 0.
                (total, staggered) = per_experiment.get(ExpID, (0, 0))
                reads = DataDict[ExpID]['ReadNumber']
                if reads:
                    RPM = total / (reads / 1000000.0)
                else:
                    # An empty SAM file would otherwise divide by zero.
                    RPM = 0.0
                rpm_columns.append(str(RPM))
                staggered_columns.append(str(staggered))
            row = [chrom, str(left), str(right), strand]
            outfile.write('\t'.join(row + rpm_columns + staggered_columns) + '\n')


def run():
    """Combine per-experiment junction counts into one RPM table.

    Command line: <list of files filename> <output filename>.
    The list file names, for every experiment, a SAM file (used to count
    the distinct reads for normalisation) and a junctions file.  Prints
    the usage text and exits with status 1 when either argument is missing.
    """
    # Both positional arguments are required; checking for only one would
    # let sys.argv[2] raise IndexError instead of showing the usage text.
    if len(sys.argv) < 3:
        print('usage: python %s <list of files filename> <output filename>' % sys.argv[0])
        print('\tlist of files format: ExpID <tab> SAMfile <tab> junctionsfile')
        sys.exit(1)

    inputfilename = sys.argv[1]
    outputfilename = sys.argv[2]

    DataDict = _read_experiment_list(inputfilename)
    JunctionsDict = {}

    for ExpID in sorted(DataDict):
        print('%s %s' % (ExpID, DataDict[ExpID]['junctions']))
        _load_junctions(DataDict[ExpID]['junctions'], ExpID, JunctionsDict)
        print('%s %s' % (ExpID, DataDict[ExpID]['SAM']))
        DataDict[ExpID]['ReadNumber'] = _count_unique_reads(DataDict[ExpID]['SAM'], ExpID)

    _write_table(outputfilename, DataDict, JunctionsDict)

# Run the pipeline only when executed as a script, so that importing this
# module does not start processing files named in the importer's sys.argv.
if __name__ == '__main__':
    run()

