##################################
#                                #
# Last modified 09/01/2010       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import numpy
from numpy import linalg
import math, random, time, copy 
from operator import add

def _read_tss_table(tss_path):
    """Load the TSS table into a dict keyed by (chrom, position, strand).

    Each non-blank line is tab-separated: title, chromosome, start, end,
    strand.  The TSS position is the integer midpoint of start and end.
    A later line with the same key replaces the earlier entry.
    """
    tss_table = {}
    with open(tss_path) as handle:
        for line in handle:
            if not line.strip():
                continue
            fields = line.strip().split('\t')
            # Floor division keeps the position an int on Python 2 and 3,
            # so it can match the integer coordinates read from the GTF.
            position = (int(fields[2]) + int(fields[3])) // 2
            tss_table[(fields[1], position, fields[4])] = {
                'transcripts': [],
                'totalFPKM': 0,
                'title': fields[0],
            }
    return tss_table


def _index_transcripts(gtf_path, tss_table):
    """Attach each GTF transcript to its TSS and return {transcript_id: key}.

    Only records whose feature column is 'transcript' are used.  The TSS is
    the start coordinate on the '+' strand and the end coordinate on '-'.
    Records with any other strand are skipped with a warning on stderr.

    Raises:
        KeyError: if a transcript's TSS is not present in ``tss_table``.
    """
    id_to_tss = {}
    with open(gtf_path) as handle:
        for line in handle:
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            if fields[2] != 'transcript':
                continue
            transcript_id = fields[8].split(' transcript_id "')[1].split('";')[0]
            strand = fields[6]
            if strand == '+':
                position = int(fields[3])
            elif strand == '-':
                position = int(fields[4])
            else:
                # Without a strand the TSS is undefined; never fall back to
                # the coordinate left over from a previous record.
                sys.stderr.write('skipping transcript %s: unsupported strand %r\n'
                                 % (transcript_id, strand))
                continue
            key = (fields[0], position, strand)
            if key not in tss_table:
                raise KeyError('no TSS entry %r for transcript %s'
                               % (key, transcript_id))
            tss_table[key]['transcripts'].append(transcript_id)
            id_to_tss[transcript_id] = key
    return id_to_tss


def _accumulate_fpkm(expr_path, tss_table, id_to_tss):
    """Add every transcript's FPKM to the running total of its TSS.

    The FPKM column defaults to index 5; a line whose first field is
    'trans_id' is treated as a header and re-locates the column by the
    name 'FPKM'.  Transcripts absent from ``id_to_tss`` are reported on
    stdout and ignored.
    """
    fpkm_column = 5
    with open(expr_path) as handle:
        for line in handle:
            if not line.strip():
                continue
            fields = line.strip().split('\t')
            if fields[0] == 'trans_id':
                fpkm_column = fields.index('FPKM')
                # The two spaces reproduce the original print-statement output.
                print('FPKM field ID =  %d' % fpkm_column)
                continue
            fpkm = float(fields[fpkm_column])
            transcript_id = fields[0]
            if transcript_id not in id_to_tss:
                print('TranscriptID not found')
                continue
            tss_table[id_to_tss[transcript_id]]['totalFPKM'] += fpkm


def _write_report(out_path, tss_table):
    """Write one tab-separated row per TSS, sorted by (chrom, position, strand).

    Every transcript ID is followed by a comma, so the last column keeps
    the trailing comma (or is empty when the TSS has no transcripts).
    """
    with open(out_path, 'w') as out:
        # The header names all six columns that the data rows contain.
        out.write('#TSS_ID\tTotal_FPKM\tchr\tTSS\tstrand\tTranscripts\n')
        for key in sorted(tss_table):
            chrom, position, strand = key
            entry = tss_table[key]
            transcripts = ''.join(tid + ',' for tid in entry['transcripts'])
            row = [entry['title'], str(entry['totalFPKM']), chrom,
                   str(position), strand, transcripts]
            out.write('\t'.join(row) + '\n')


def run():
    """Sum Cufflinks transcript FPKM per TSS and write a per-TSS table.

    Command-line arguments (read from ``sys.argv``):
        1. transcript expression file from Cufflinks
        2. GENCODE GTF annotation
        3. TSS file
        4. output file name

    Prints a usage message and exits with status 1 when fewer than four
    arguments are supplied.
    """
    # Four positional arguments are required, so argv needs five entries.
    if len(sys.argv) < 5:
        print('usage: python %s <transcript expression from cufflinks> <gencode gtf> <TSS file> <outfile>' % sys.argv[0])
        sys.exit(1)

    expr_path = sys.argv[1]
    gtf_path = sys.argv[2]
    tss_path = sys.argv[3]
    out_path = sys.argv[4]

    tss_table = _read_tss_table(tss_path)
    id_to_tss = _index_transcripts(gtf_path, tss_table)
    _accumulate_fpkm(expr_path, tss_table, id_to_tss)
    _write_report(out_path, tss_table)
	
# Run the pipeline only when executed as a script, so importing this
# module does not parse sys.argv or exit the interpreter.
if __name__ == '__main__':
    run()
