##################################
#                                #
# Last modified 03/08/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
import numpy
from sets import Set

def _load_regions(path, chr_field_id):
    """Read a tab-delimited regions file into {chromosome: [(left, right), ...]}.

    The chromosome name is taken from column ``chr_field_id`` (0-based) and
    the left/right coordinates from the two columns immediately after it.
    Blank lines and lines starting with '#' are skipped.
    """
    regions = {}
    with open(path) as handle:
        for line in handle:
            if not line.strip() or line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            chrom = fields[chr_field_id]
            left = int(fields[chr_field_id + 1])
            right = int(fields[chr_field_id + 2])
            regions.setdefault(chrom, []).append((left, right))
    return regions


def _overlaps(window_left, window_right, intervals):
    """Return True if (window_left, window_right) overlaps any (left, right) interval.

    Partial overlap on either side and full containment in either direction
    all count; intervals that merely touch at an endpoint do not (strict
    inequalities, as in the original endpoint test).
    """
    return any(window_left < right and window_right > left
               for (left, right) in intervals)


def _load_tss(path, tss_radius, regions):
    """Read the TSS-expression file and flag each TSS as covered or not.

    Expected columns: chr, TSS, strand, transcripts, FPKM, FPKM_lo, FPKM_hi.
    Returns a list of (FPKM, chromosome, TSS position, covered) tuples.
    Lines starting with '#', blank lines and chrM entries are skipped.
    A TSS is covered when the window [TSS - tss_radius, TSS + tss_radius]
    overlaps any region on the same chromosome.
    """
    tss_list = []
    with open(path) as handle:
        for line in handle:
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            chrom = fields[0]
            if chrom == 'chrM':
                continue
            position = int(fields[1])
            fpkm = float(fields[4])
            covered = _overlaps(position - tss_radius, position + tss_radius,
                                regions.get(chrom, ()))
            tss_list.append((fpkm, chrom, position, covered))
    return tss_list


def _tally(thresholds, tss_list):
    """Count covered / not-covered TSSs per FPKM bin.

    ``thresholds`` must be sorted ascending. Bin ``t`` holds the TSSs whose
    FPKM is >= t and below the next higher threshold; the top bin is
    open-ended. TSSs below the lowest threshold (or with NaN FPKM) are not
    counted. Each TSS is counted exactly once.
    Returns {threshold: {'covered': n, 'notcovered': m}}.
    """
    counts = dict((t, {'covered': 0, 'notcovered': 0}) for t in thresholds)
    for (fpkm, _chrom, _position, covered) in tss_list:
        # Highest threshold not exceeding the FPKM identifies the bin.
        bin_key = next((t for t in reversed(thresholds) if fpkm >= t), None)
        if bin_key is None:
            continue
        counts[bin_key]['covered' if covered else 'notcovered'] += 1
    return counts


def run():
    """Command-line entry point; see the usage message for the arguments.

    Writes one line per FPKM threshold to the output file: the threshold,
    then the number of TSSs in that FPKM bin whose +/- TSSradius window does
    and does not overlap a region from the regions file.
    """
    # Six positional arguments are required (sys.argv[1] .. sys.argv[6]).
    if len(sys.argv) < 7:
        print('usage: python %s TSS-expression-file [FPKM (F1,F2,F3,....Fn)] '
              'TSSradius regions regionchrfield outfile' % sys.argv[0])
        print("        TSS-expression format: #chr    TSS     strand  transcripts     FPKM    FPKM_lo FPKM_hi")
        sys.exit(1)

    tss_file = sys.argv[1]
    # Duplicate values collapse into a single bin once converted to float.
    thresholds = sorted(set(float(t) for t in sys.argv[2].split(',')))
    tss_radius = int(sys.argv[3])
    regions_file = sys.argv[4]
    chr_field_id = int(sys.argv[5])
    out_path = sys.argv[6]

    regions = _load_regions(regions_file, chr_field_id)
    tss_list = _load_tss(tss_file, tss_radius, regions)
    counts = _tally(thresholds, tss_list)

    with open(out_path, 'w') as outfile:
        outfile.write("#TSS_FPKM\tcovered\tnot_covered\n")
        for t in thresholds:
            outfile.write('%s\t%d\t%d\n' % (t, counts[t]['covered'], counts[t]['notcovered']))
   
# Run the command-line tool only when executed as a script, so the module's
# helpers can be imported without parsing sys.argv or calling sys.exit.
if __name__ == '__main__':
    run()
