##################################
#                                #
# Last modified 06/08/2013       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string

def _load_tss(path):
    """Read the TSS annotation file.

    Each non-comment, non-blank line is tab-separated: chromosome, left,
    right, strand, comma-separated gene names, followed by any further
    columns (ignored). The TSS position is the integer midpoint of left
    and right.

    Returns a pair (tss_by_chrom, gene_counts):
      tss_by_chrom -- {chromosome: {(position, strand): [gene, ...]}}
      gene_counts  -- {gene: 0} for every gene name seen in the file
    """
    tss_by_chrom = {}
    gene_counts = {}
    with open(path) as handle:
        for line in handle:
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            chrom = fields[0]
            # '//' keeps the integer midpoint under both Python 2 and 3.
            midpoint = (int(fields[1]) + int(fields[2])) // 2
            strand = fields[3]
            site = tss_by_chrom.setdefault(chrom, {})
            genes_here = site.setdefault((midpoint, strand), [])
            for gene in fields[4].split(','):
                gene_counts[gene] = 0
                genes_here.append(gene)
    return tss_by_chrom, gene_counts


def _parse_anchor(anchor):
    """Split 'chrN:left..right' into (chromosome, int(left), int(right))."""
    chrom, span = anchor.split(':')[:2]
    left, right = span.split('..')[:2]
    return chrom, int(left), int(right)


def _count_anchor(tss_by_chrom, gene_counts, chrom, left, right):
    """Add one to every gene whose TSS lies in the half-open span [left, right)."""
    sites = tss_by_chrom.get(chrom)
    if not sites:
        # Anchor on a chromosome without any annotated TSS: nothing to count.
        return
    for position in range(left, right):
        for strand in ('+', '-'):
            for gene in sites.get((position, strand), ()):
                gene_counts[gene] += 1


def run():
    """Count, per gene, the interaction anchors that overlap its TSS.

    Command-line arguments (read from sys.argv):
      1. interactions file; the 4th tab-separated column names the pair as
         'chrA:left..right-chrB:left..right,score'
      2. TSS file (see _load_tss for the expected columns)
      3. output file name

    Interactions whose 4th column has already been seen are skipped, so each
    distinct interaction is counted once. Both anchors of an interaction are
    scanned, and every gene with a TSS inside an anchor is incremented.

    The output is a tab-separated table with a '#gene\\tnumber_interactions'
    header and one line per gene from the TSS file (including genes with a
    count of zero), sorted by gene name.

    Prints usage and exits with status 1 when fewer than three arguments
    are supplied.
    """
    # Three positional arguments are required; sys.argv[0] is the script.
    if len(sys.argv) < 4:
        print('usage: python %s Interactions.bed TSS outfilename' % sys.argv[0])
        print('assumed interactions format: chr1\t32791791\t32792312\tchr1:32791791..32792312-chr11:132445143..132445663,2\t200\t.\t32791791\t32792312\t255,0,0\t1\t521\t0')
        print('assumed TSS format: chr1\t1051636\t1051836\t-\tC1orf159\tNM_017891\tNM_017891; the middle point of the two coordinates will be used as the TSS')
        sys.exit(1)

    interactions = sys.argv[1]
    tss_filename = sys.argv[2]
    outputfilename = sys.argv[3]

    tss_by_chrom, gene_counts = _load_tss(tss_filename)

    seen_interactions = set()
    with open(interactions) as handle:
        for line_number, line in enumerate(handle, 1):
            if line_number % 1000 == 0:
                print('%d lines processed' % line_number)
            if line.startswith('#') or not line.strip():
                continue
            pair_name = line.strip().split('\t')[3]
            # The same interaction is typically listed once per anchor;
            # count it only the first time it appears.
            if pair_name in seen_interactions:
                continue
            seen_interactions.add(pair_name)
            # Drop the trailing ',score' before splitting the two anchors.
            first, second = pair_name.split(',')[0].split('-')[:2]
            for anchor in (first, second):
                chrom, left, right = _parse_anchor(anchor)
                _count_anchor(tss_by_chrom, gene_counts, chrom, left, right)

    with open(outputfilename, 'w') as outfile:
        outfile.write('#gene\tnumber_interactions' + '\n')
        for gene in sorted(gene_counts):
            outfile.write(gene + '\t' + str(gene_counts[gene]) + '\n')

# Run only when executed as a script, so importing this module does not
# trigger argument parsing, file I/O, or sys.exit().
if __name__ == '__main__':
    run()

