##################################
#                                #
# Last modified 11/10/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set

def _window(pos, strand, radius, downstream):
    """Return the full range of positions tracked around one TSS.

    For '+' the window is [pos - radius, pos + radius + downstream); for '-'
    it is mirrored to [pos - radius - downstream, pos + radius). Any other
    strand value yields an empty range.
    """
    if strand == '+':
        return range(pos - radius, pos + radius + downstream)
    if strand == '-':
        return range(pos - radius - downstream, pos + radius)
    return range(0)


def _downstream_region(pos, strand, radius, downstream):
    """Return the positions of the downstream region for a '+' or '-' TSS."""
    if strand == '+':
        return range(pos + radius, pos + radius + downstream)
    return range(pos - radius - downstream, pos - radius)


def _read_tss(tss_path, radius, downstream):
    """Parse the TSS file and pre-seed the coverage dictionaries.

    Each non-comment line must hold tab-separated ``chr``, ``TSS`` (integer)
    and ``strand`` fields. Returns ``(tss_dict, chip_cov, input_cov)`` where
    ``tss_dict`` is keyed by ``(chr, TSS, strand)`` and the two coverage
    dictionaries map ``chr -> {position: 0}`` for every position inside a
    TSS window.
    """
    tss_dict = {}
    chip_cov = {}
    input_cov = {}
    with open(tss_path) as handle:
        for line in handle:
            # Skip comments; also skip blank lines, which would otherwise
            # fail with an IndexError when the fields are accessed.
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            chrom = fields[0]
            pos = int(fields[1])
            strand = fields[2]
            tss_dict[(chrom, pos, strand)] = {}
            if chrom not in chip_cov:
                chip_cov[chrom] = {}
                input_cov[chrom] = {}
            for i in _window(pos, strand, radius, downstream):
                chip_cov[chrom][i] = 0
                input_cov[chrom][i] = 0
    return tss_dict, chip_cov, input_cov


def _load_coverage(path, coverage, label):
    """Fill ``coverage`` in place with scores read from a signal file.

    Each non-comment line must hold tab-separated ``chr``, ``left``,
    ``right`` and ``score`` fields. Only positions already present in
    ``coverage`` are updated; a later interval covering the same position
    overwrites the earlier score. ``label`` is used in the progress message
    printed every million lines.
    """
    with open(path) as handle:
        for count, line in enumerate(handle, 1):
            if count % 1000000 == 0:
                print(str(count // 1000000) + 'M lines in ' + label + ' processed')
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            chrom = fields[0]
            left = int(fields[1])
            right = int(fields[2])
            score = float(fields[3])
            positions = coverage.get(chrom)
            if positions is None:
                continue
            for i in range(left, right):
                if i in positions:
                    positions[i] = score


def _remove_neighbours(tss_dict, chrom, pos, window):
    """Delete every other TSS (either strand) located inside ``window``.

    The entry at ``pos`` itself is left alone. Returns True when at least
    one neighbouring TSS was found and removed.
    """
    found = False
    for i in window:
        if i == pos:
            continue
        for other_strand in ('+', '-'):
            if (chrom, i, other_strand) in tss_dict:
                del tss_dict[(chrom, i, other_strand)]
                found = True
    return found


def run():
    """Compute a stalling ratio for every isolated TSS and write a table.

    Command-line arguments (read from ``sys.argv``):
        1. TSS file (tab-separated: chr, TSS, strand)
        2. ChIP signal file (tab-separated: chr, left, right, score)
        3. Input signal file (same layout as the ChIP file)
        4. TSS radius in bp
        5. downstream region length in bp
        6. pseudocount added to numerator and denominator
        7. output file name

    A TSS is reported only if no other TSS lies inside its window; when
    neighbours are found, both the TSS and its neighbours are dropped.
    For each remaining TSS the per-base scores are summed over the TSS
    region ([TSS - radius, TSS + radius)) and over the strand-specific
    downstream region, and the ratio

        (max(ChIP_TSS - Input_TSS, 0) + pseudocount) /
        (max(ChIP_down - Input_down, 0) + pseudocount)

    is written alongside the four sums.

    Exits with status 1 after printing usage when too few arguments are
    given. Raises ZeroDivisionError if the pseudocount is 0 and the
    background-subtracted downstream signal is not positive.
    """
    # Seven user arguments are required, so sys.argv needs eight entries;
    # the previous "< 7" check let a six-argument call through, which then
    # failed with IndexError on sys.argv[7].
    if len(sys.argv) < 8:
        print('usage: python %s <TSS file> <ChIP wig> <Input wig> <TSS radius> <downstream bp> pseudocount outfilename ' % sys.argv[0])
        print('     TSS file format: chr TSS strand')
        sys.exit(1)

    tss_path = sys.argv[1]
    chip_path = sys.argv[2]
    input_path = sys.argv[3]
    radius = int(sys.argv[4])
    downstream = int(sys.argv[5])
    pseudocount = float(sys.argv[6])
    outfilename = sys.argv[7]

    tss_dict, chip_cov, input_cov = _read_tss(tss_path, radius, downstream)

    _load_coverage(chip_path, chip_cov, 'ChIP')
    _load_coverage(input_path, input_cov, 'Input')

    header = '#chr\tTSS\tstrand\tTSS_ChIP_signal\tTSS_input_signal\tDownstream_ChIP_signal\tDownstream_input_signal\tStalling_Ratio'

    with open(outfilename, 'w') as outfile:
        outfile.write(header + '\n')

        # sorted() takes a snapshot of the keys, so entries can be deleted
        # from tss_dict while walking the list.
        for key in sorted(tss_dict):
            # Already removed as the neighbour of an earlier TSS.
            if key not in tss_dict:
                continue
            chrom, pos, strand = key

            window = _window(pos, strand, radius, downstream)
            if _remove_neighbours(tss_dict, chrom, pos, window):
                del tss_dict[key]
                continue

            chip_tss = 0
            input_tss = 0
            chip_down = 0
            input_down = 0
            if strand in ('+', '-'):
                chip_positions = chip_cov[chrom]
                input_positions = input_cov[chrom]
                # Explicit accumulation keeps the left-to-right float
                # addition order of the original implementation.
                for i in range(pos - radius, pos + radius):
                    chip_tss += chip_positions[i]
                    input_tss += input_positions[i]
                for i in _downstream_region(pos, strand, radius, downstream):
                    chip_down += chip_positions[i]
                    input_down += input_positions[i]

            stalling_ratio = (max(chip_tss - input_tss, 0) + pseudocount) / (max(chip_down - input_down, 0) + pseudocount)
            columns = [chrom, str(pos), strand, str(chip_tss), str(input_tss),
                       str(chip_down), str(input_down), str(stalling_ratio)]
            outfile.write('\t'.join(columns) + '\n')
   
# Unconditional module-level call: run() executes (and reads sys.argv)
# whenever this file is executed or imported.
run()
