##################################
#                                #
# Last modified 06/30/2010       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import bisect
import math
import string
import sys

def _parse_tss_annotation(gtf):
    """Index transcript start sites from a GTF file by chromosome.

    Only rows whose 3rd column is 'transcript' are used. The TSS is the
    4th column (start) for '+' transcripts and the 5th column (end) for
    '-' transcripts. The gene label is the value of the gene_name attribute
    in the 9th column. Rows are skipped if they are '#' comment/header lines,
    have fewer than nine columns, or belong to transcripts whose strand is
    neither '+' nor '-'.

    Returns a dict mapping chromosome -> (positions, entries). `entries` is
    a list of (TSS, strand, gene) tuples sorted by TSS, and `positions` is
    the parallel list of TSS coordinates used for bisection. If several
    transcripts share the same (TSS, strand), the last one read supplies
    the gene name.
    """
    sites_by_chr = {}
    with open(gtf) as handle:
        for line_no, line in enumerate(handle, 1):
            if line_no % 100000 == 0:
                print('%d lines processed' % line_no)
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            if len(fields) < 9 or fields[2] != 'transcript':
                continue
            strand = fields[6]
            if strand == '+':
                tss = int(fields[3])
            elif strand == '-':
                tss = int(fields[4])
            else:
                # Without a strand there is no defined TSS.
                continue
            # Raises IndexError if the attribute column lacks gene_name.
            gene = fields[8].split('gene_name "')[1].split('";')[0]
            sites_by_chr.setdefault(fields[0], {})[(tss, strand)] = gene

    index = {}
    for chrom, sites in sites_by_chr.items():
        entries = sorted((tss, strand, gene)
                         for (tss, strand), gene in sites.items())
        index[chrom] = ([entry[0] for entry in entries], entries)
    return index


def _read_peak_file_list(regions):
    """Parse a 'TFname,path' list file into a dict TF -> [path, ...].

    Paths keep their file order. Blank lines are ignored. Only the text
    between the first and second comma is used as the path.
    """
    files_by_tf = {}
    with open(regions) as handle:
        for line in handle:
            line = line.strip()
            if not line:
                continue
            fields = line.split(',')
            files_by_tf.setdefault(fields[0], []).append(fields[1])
    return files_by_tf


def _read_peaks(path):
    """Yield (chromosome, position, score) for each data row of a peak file.

    The position is column 2 (start) plus column 5 (summit offset). The
    score is column 7 parsed as a float. The following lines are skipped:
    '#' comments, blank lines, and the header row whose first column is
    'chr'.
    NOTE(review): start + summit assumes a MACS 1.x style relative summit
    column; confirm for other peak-caller versions.
    """
    with open(path) as handle:
        for line in handle:
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            if fields[0] in ('', 'chr'):
                continue
            yield fields[0], int(fields[1]) + int(fields[4]), float(fields[6])


def _closest_target(positions, entries, peak, tss_up, tss_down):
    """Return the gene whose TSS is closest to `peak` inside its window.

    For a '+' TSS the peak must satisfy -tss_up < peak - TSS < tss_down.
    For a '-' TSS the window is mirrored: -tss_down < peak - TSS < tss_up.
    Both bounds are strict. `positions`/`entries` are one chromosome's
    sorted index from _parse_tss_annotation. Distance ties go to the lowest
    TSS coordinate. Returns '' when no TSS qualifies.
    """
    # No qualifying TSS can lie farther away than the larger window side,
    # so only that slice of the sorted index needs to be examined.
    reach = max(tss_up, tss_down)
    lo = bisect.bisect_left(positions, peak - reach)
    hi = bisect.bisect_right(positions, peak + reach)

    best_dist = None
    target = ''
    for tss, strand, gene in entries[lo:hi]:
        offset = peak - tss
        if strand == '+':
            inside = -tss_up < offset < tss_down
        else:
            inside = -tss_down < offset < tss_up
        if inside and (best_dist is None or abs(offset) < best_dist):
            best_dist = abs(offset)
            target = gene
    return target


def run():
    """Assign peaks to their closest TSS and write per-TF gene scores.

    Command line (documented form):
        <annotation gtf> <list file> <TSS upstream> <TSS downstream> <outfile>
    The older 7-argument form is also accepted. In that form the window
    sizes and the output file are arguments 5-7, and arguments 3-4 are
    ignored.

    Each line of the list file is 'TFname,path to peak file'. Replicates
    share a TFname. Every peak is assigned to the single closest TSS whose
    window contains it (see _closest_target). For each (gene, TF), the
    scores of the assigned peaks are summed, each divided by the number of
    files listed for that TF.

    One line is written per (gene, TF) with a positive total, sorted by
    gene then TF:
        TF \t 'pd' \t gene \t total
    Exits with status 1 and prints usage when too few arguments are given.
    """
    argv = sys.argv
    if len(argv) >= 8:
        # Legacy layout: these values used to be read from positions 5-7.
        up_arg, down_arg, outputfilename = argv[5], argv[6], argv[7]
    elif len(argv) >= 6:
        up_arg, down_arg, outputfilename = argv[3], argv[4], argv[5]
    else:
        print('usage: python %s <annotation gtf> <list of region calls filename> <TSS upstream> <TSS downstream> <outfilename>' % argv[0])
        print('\tformat of list of MACS peaks.xls files:>')
        print('\tTFname,path>')
        print('Note1: label all replicates with the same TFname>')
        print('Note2: a TSS will be considered a target only if it is the closest one to the peak>')
        sys.exit(1)

    gtf = argv[1]
    regions = argv[2]
    TSSup = int(up_arg)
    TSSdown = int(down_arg)

    tss_index = _parse_tss_annotation(gtf)
    print('finished inputting annotation')
    files_by_tf = _read_peak_file_list(regions)

    print('processing peak calls')
    totals = {}  # (gene, TF) -> accumulated, replicate-averaged score
    for TF in files_by_tf:
        print(TF)
        n_files = len(files_by_tf[TF])
        for path in files_by_tf[TF]:
            print('processing ' + path)
            for chrom, peak, score in _read_peaks(path):
                if chrom not in tss_index:
                    continue
                positions, entries = tss_index[chrom]
                gene = _closest_target(positions, entries, peak,
                                       TSSup, TSSdown)
                if gene != '':
                    key = (gene, TF)
                    totals[key] = totals.get(key, 0) + score / n_files

    print('summarizing output stats')
    with open(outputfilename, 'w') as outfile:
        # Sorting (gene, TF) tuples orders by gene, then by TF.
        for gene, TF in sorted(totals):
            value = totals[(gene, TF)]
            if value > 0:
                outfile.write(TF + '\tpd\t' + gene + '\t' + str(value) + '\n')

# Run only when executed as a script, so the module can be imported
# (e.g. for testing) without parsing sys.argv or touching files.
if __name__ == '__main__':
    run()

