##################################
#                                #
# Last modified 08/23/2010       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from sets import Set

def _load_dataset_ids(path):
    """Return the sorted sample IDs from the first tab-separated column of *path*.

    Blank lines are skipped: an empty ID is a substring of every string and
    would therefore "match" every column of the tracking file.
    """
    with open(path) as handle:
        candidates = [line.strip().split('\t')[0] for line in handle]
    dataset_ids = [sample for sample in candidates if sample]
    dataset_ids.sort()
    return dataset_ids


def _load_expression(path, dataset_ids):
    """Parse the cuffcompare tracking file into {transcript: {sample_id: value}}.

    For every non-comment line the first column is the transcript key and the
    columns from the fifth onwards are searched for each sample ID.  The value
    stored is the float at '|'-separated position 4 of the matching column
    (the original code called it FPKM_lo); samples with no matching column
    get 0.  A progress message is printed every 100000 lines.
    """
    expression = {}
    with open(path) as handle:
        for line_number, line in enumerate(handle, 1):
            if line_number % 100000 == 0:
                print('%d lines processed' % line_number)
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            sample_fields = fields[4:]
            per_sample = {}
            for sample in dataset_ids:
                value = 0
                for field in sample_fields:
                    # NOTE(review): substring match -- an ID contained in
                    # another ID (e.g. 's1' vs 's10') matches both columns;
                    # the last matching column wins.
                    if sample in field:
                        value = float(field.split('|')[4])
                per_sample[sample] = value
            # A transcript that appears twice keeps only its last line.
            expression[fields[0]] = per_sample
    return expression


def _load_tss(path):
    """Read the TSS file and return (known, novel).

    Columns used (tab-separated, 0-based): 0 gene, 1 ID, 2 chr, 3 TSS,
    4 strand, 5 type ('known' or 'novel'), 6 comma-separated transcripts.

    known maps gene -> list of known TSS positions (file order).
    novel maps gene -> {'ID', 'chr', 'strand', 'TSSs'} where 'TSSs' maps each
    novel TSS position to its transcript list; ID/chr/strand come from the
    first novel line seen for the gene.
    """
    known = {}
    novel = {}
    with open(path) as handle:
        for line in handle:
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            if fields == ['']:
                # Blank line: nothing to parse.
                continue
            gene = fields[0]
            position = int(fields[3])
            tss_kind = fields[5]
            if tss_kind == 'known':
                known.setdefault(gene, []).append(position)
            elif tss_kind == 'novel':
                entry = novel.get(gene)
                if entry is None:
                    entry = {'ID': fields[1], 'chr': fields[2],
                             'strand': fields[4], 'TSSs': {}}
                    novel[gene] = entry
                # The element after the last comma is discarded, i.e. the
                # list is expected to end with a trailing comma.
                # NOTE(review): without a trailing comma the final transcript
                # is dropped -- confirm against the producer of this file.
                entry['TSSs'][position] = fields[6].split(',')[0:-1]
    return known, novel


def _closest_known_offset(position, known_positions, strand):
    """Return, as a string, the signed offset from *position* to the nearest known TSS.

    The offset is position - known_TSS, negated on the '-' strand.  On ties
    the known TSS appearing later in *known_positions* wins.  'NA' is
    returned when there is no known TSS to compare against.
    """
    closest = None
    for known in known_positions:
        offset = position - known
        if closest is None or abs(offset) <= abs(closest):
            closest = offset
    if closest is None:
        return 'NA'
    if strand == '-':
        closest = -closest
    return str(closest)


def run():
    """Write a per-sample expression table for novel TSSs.

    Command line: Novel_TSS_file cuffcompare_tracking dataset_ID_list
    outputfilename (read from sys.argv).  Prints usage and exits with status
    1 when fewer than four arguments are supplied.

    One output row is written per novel TSS, ordered by gene and then by TSS
    position, holding: gene, ID, chr, TSS, strand, 'novel', offset to the
    closest known TSS of the same gene ('NA' if the gene has none), the
    comma-joined transcripts, and for each sample ID the sum over those
    transcripts of the value loaded from the tracking file.  Transcripts
    absent from the tracking file are reported on stdout and contribute
    nothing to the sum.

    Only syntax valid on both Python 2.6+ and Python 3 is used.
    """
    # Four positional arguments are read below, so argv needs five entries.
    if len(sys.argv) < 5:
        print('usage: python %s Novel_TSS_file cuffcompare_tracking dataset_ID_list outputfilename' % sys.argv[0])
        print('\tdataset_ID_list format: list of unique sample names as they appear in the cuffcompare tracking file')
        sys.exit(1)

    novel_tss_path, tracking_path, id_list_path, out_path = sys.argv[1:5]

    dataset_ids = _load_dataset_ids(id_list_path)
    expression = _load_expression(tracking_path, dataset_ids)
    known_tss, novel_tss = _load_tss(novel_tss_path)

    header = '#gene\tID\tchr\tTSS\tstrand\tknown/novel\tclosest_known_TSS\ttranscripts'
    with open(out_path, 'w') as outfile:
        outfile.write('\t'.join([header] + dataset_ids) + '\n')

        for gene in sorted(novel_tss):
            entry = novel_tss[gene]
            strand = entry['strand']
            known_positions = known_tss.get(gene, [])
            for position in sorted(entry['TSSs']):
                transcripts = entry['TSSs'][position]
                # Joining (rather than trimming a trailing comma) keeps the
                # transcripts column in place even when the list is empty.
                columns = [gene, entry['ID'], entry['chr'], str(position),
                           strand, 'novel',
                           _closest_known_offset(position, known_positions, strand),
                           ','.join(transcripts)]
                for sample in dataset_ids:
                    total = 0
                    for transcript in transcripts:
                        try:
                            total += expression[transcript][sample]
                        except KeyError:
                            print('%s not found in expression data' % transcript)
                    columns.append(str(total))
                outfile.write('\t'.join(columns) + '\n')
   
# Script entry point: run() is called unconditionally at module level, so it
# executes (and reads sys.argv) on import as well as on direct execution.
run()
