##################################
#                                #
# Last modified 2017/05/10       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import math
import os
import string
import subprocess
import sys
from sets import Set

def _iter_fimo_lines(path):
    """Yield the lines of the FIMO file at *path*, decompressing if needed.

    Files ending in '.bz2' are streamed through ``bzip2 -cd`` and files
    ending in '.gz' through ``gunzip -c``; anything else is opened directly.
    The external tool is started with an argument list (no shell), so file
    names containing spaces or shell metacharacters are handled safely.

    Raises IOError if the decompression tool exits with a non-zero status.
    """
    if path.endswith('.bz2'):
        cmd = ['bzip2', '-cd', path]
    elif path.endswith('.gz'):
        cmd = ['gunzip', '-c', path]
    else:
        cmd = None

    if cmd is None:
        with open(path) as handle:
            for line in handle:
                yield line
        return

    # universal_newlines=True makes the pipe yield text on Python 2 and 3.
    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                            universal_newlines=True)
    try:
        for line in proc.stdout:
            yield line
    finally:
        proc.stdout.close()
        returncode = proc.wait()
    if returncode != 0:
        raise IOError('%s exited with status %d while reading %s'
                      % (cmd[0], returncode, path))


def _load_fimo(path):
    """Read motif occurrences from the FIMO file at *path*.

    Lines starting with '#' and blank lines are skipped.  Every other line is
    split on tabs; column 0 is used as the motif name, column 1 as the
    chromosome, and columns 2 and 3 as the integer left/right coordinates.

    Returns a tuple ``(motifs, sites)`` where *motifs* is the sorted list of
    all motif names seen and *sites* maps chromosome -> motif -> left -> right.
    A later line with the same chromosome, motif and left coordinate replaces
    the earlier one.
    """
    motifs = set()
    sites = {}
    for line in _iter_fimo_lines(path):
        if line.startswith('#'):
            continue
        fields = line.strip().split('\t')
        if fields == ['']:
            continue
        mot = fields[0]
        motifs.add(mot)
        chrom = fields[1]
        left = int(fields[2])
        right = int(fields[3])
        sites.setdefault(chrom, {}).setdefault(mot, {})[left] = right
    return sorted(motifs), sites


def _count_nonoverlapping(occurrences):
    """Count occurrences that do not start inside the preceding occurrence.

    *occurrences* is a list of ``(left, right)`` tuples sorted by ``left``.
    The first occurrence is always counted.  Each following occurrence is
    counted unless its left coordinate lies strictly between the left and
    right coordinates of the occurrence immediately before it in the list
    (the comparison is with the previous list entry, whether or not that
    entry was itself counted).
    """
    if not occurrences:
        return 0
    count = 1
    for prev, cur in zip(occurrences, occurrences[1:]):
        if not (prev[0] < cur[0] < prev[1]):
            count += 1
    return count


def run():
    """Count motif occurrences per region and write them as extra columns.

    Command-line arguments (read from ``sys.argv``):
        1. FIMO file (plain text, '.gz' or '.bz2')
        2. tab-delimited regions file
        3. chrFieldID: 0-based index of the chromosome column in the regions
           file; the two columns after it are read as integer left and right
           coordinates
        4. output file name

    Each regions line is copied to the output with one extra tab-separated
    column per motif (motifs in sorted order).  For lines starting with '#'
    the extra columns hold the motif names; for all other lines they hold
    the number of counted motif occurrences whose left coordinate falls in
    ``[left, right)`` of the region.

    Prints a usage message and exits with status 1 when fewer than four
    arguments are supplied.
    """
    # Four user arguments are required, i.e. sys.argv must have 5 entries.
    if len(sys.argv) < 5:
        print('usage: python %s FIMO.txt regions chrFieldID outfilename' % sys.argv[0])
        print('\tNote: the FIMO files can be in .gz or .bz2 format')
        sys.exit(1)

    FIMO = sys.argv[1]
    regions = sys.argv[2]
    chrFieldID = int(sys.argv[3])
    outfilename = sys.argv[4]

    allmotifs, FIMODict = _load_fimo(FIMO)

    print('finished importing motifs')

    with open(outfilename, 'w') as outfile:
        with open(regions) as linelist:
            for k, line in enumerate(linelist, 1):
                if k % 100 == 0:
                    print('%d lines processed' % k)
                stripped = line.strip()
                if line.startswith('#'):
                    # Header line: label the new columns with the motif names.
                    outfile.write('\t'.join([stripped] + allmotifs) + '\n')
                    continue
                fields = stripped.split('\t')
                chrom = fields[chrFieldID]
                left = int(fields[chrFieldID + 1])
                right = int(fields[chrFieldID + 2])
                chromSites = FIMODict.get(chrom, {})
                # Built once per region and reused for every motif.
                positions = range(left, right)
                counts = []
                for mot in allmotifs:
                    motSites = chromSites.get(mot)
                    if motSites:
                        # Scanning positions in increasing order yields the
                        # occurrences already sorted by left coordinate.
                        motOccurences = [(i, motSites[i])
                                         for i in positions if i in motSites]
                        MC = _count_nonoverlapping(motOccurences)
                    else:
                        MC = 0
                    counts.append(str(MC))
                outfile.write('\t'.join([stripped] + counts) + '\n')
   
# Run the tool only when this file is executed as a script, so that merely
# importing the module does not parse sys.argv or exit the interpreter.
if __name__ == '__main__':
    run()
