##################################
#                                #
# Last modified 2017/07/11	 # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import os
import string
import subprocess
import sys

def _load_motif_names(meme_path):
    """Return a dict mapping each motif ID in a MEME file to its TF label.

    Only lines starting with 'MOTIF ' are used.  The second
    whitespace-separated token is the motif ID and the third is taken as the
    TF label, with ':' replaced by '_'.  When a MOTIF line carries no third
    token, the motif ID itself is used as the label instead of failing.
    """
    names = {}
    with open(meme_path) as handle:
        for line in handle:
            if not line.startswith('MOTIF '):
                continue
            parts = line.split()
            if len(parts) < 2:
                # 'MOTIF' with nothing after it: no ID to register.
                continue
            motif_id = parts[1]
            label = parts[2] if len(parts) > 2 else motif_id
            names[motif_id] = label.replace(':', '_')
    return names


def _iter_fimo_lines(path):
    """Yield the text lines of `path`, decompressing .bz2 / .gz on the fly.

    Compressed files are streamed through the external `bzip2` / `gunzip`
    programs.  The command is passed as an argument list (no shell), so paths
    containing spaces or shell metacharacters are handled safely.

    Raises:
        IOError: if the decompressor exits with a non-zero status, so that a
            truncated or corrupt archive is not silently treated as complete.
    """
    if path.endswith('.bz2'):
        command = ['bzip2', '-cd', path]
    elif path.endswith('.gz'):
        command = ['gunzip', '-c', path]
    else:
        command = None

    if command is None:
        with open(path) as handle:
            for line in handle:
                yield line
        return

    proc = subprocess.Popen(command, stdout=subprocess.PIPE,
                            universal_newlines=True)
    try:
        for line in proc.stdout:
            yield line
    finally:
        proc.stdout.close()
        status = proc.wait()
    if status != 0:
        raise IOError('%s exited with status %d while reading %s'
                      % (command[0], status, path))


def _is_opposite_strand_overlap(first, second, min_overlap):
    """Tell whether two hits form a bidirectional pair.

    `first` and `second` are (left, right, strand) tuples, with `first`
    preceding `second` in sorted order.  They pair up when the strands differ,
    `second` starts within the span of `first`, and
    min(right ends) - max(left ends) is at least `min_overlap`.
    """
    left, right, strand = first
    other_left, other_right, other_strand = second
    if other_strand == strand:
        return False
    if not (left <= other_left <= right):
        return False
    return min(right, other_right) - max(left, other_left) >= min_overlap


def _count_bidirectional(hits, min_overlap):
    """Sort `hits` in place and count those paired with an adjacent hit.

    NOTE(review): as in the original implementation, each hit is compared
    only with its immediate predecessor and successor in sorted order; an
    opposite-strand overlap separated by an intervening hit is not detected.
    """
    hits.sort()
    paired = [False] * len(hits)
    for i in range(len(hits) - 1):
        if _is_opposite_strand_overlap(hits[i], hits[i + 1], min_overlap):
            paired[i] = True
            paired[i + 1] = True
    return sum(paired)


def run():
    """Report, per TF, how many FIMO hits overlap an opposite-strand hit.

    Command-line arguments (read from sys.argv):
        1. FIMO output file (plain text, .gz or .bz2).  Lines starting with
           '#' and blank lines are skipped; the first five tab-separated
           columns are read as motif ID, sequence name, start, stop, strand.
        2. MEME motif file, used to translate motif IDs into TF labels.
        3. Minimum overlap (integer) required between two opposite-strand
           hits of the same TF on the same sequence.

    Writes a tab-separated table to stdout: TF, total hits, bidirectional
    hits, and the fraction of hits that are bidirectional.

    Exits with status 1 after printing usage when fewer than three arguments
    are supplied.  Raises KeyError if a FIMO line names a motif ID that is
    absent from the MEME file.
    """
    # All three positional arguments are required (argv[0] is the script).
    if len(sys.argv) < 4:
        print('usage: python %s FIMO.txt motifs.meme minOverlap(bp)' % sys.argv[0])
        print('\tNote: the fimo.txt file can be .gz or .bz2')
        print('\tNote: the script will print to stdout')
        sys.exit(1)

    fimo_path = sys.argv[1]
    meme_path = sys.argv[2]
    min_overlap = int(sys.argv[3])

    motif_names = _load_motif_names(meme_path)

    # TF label -> sequence name -> list of (left, right, strand) hits.
    hits_by_tf = {}
    for line in _iter_fimo_lines(fimo_path):
        if line.startswith('#') or not line.strip():
            continue
        fields = line.strip().split('\t')
        tf = motif_names[fields[0]]
        chrom = fields[1]
        hit = (int(fields[2]), int(fields[3]), fields[4])
        hits_by_tf.setdefault(tf, {}).setdefault(chrom, []).append(hit)

    print('#TF\tTotalMotifhits\tBiDirectionalHits\tFractionBiDirectionalHits')

    for tf in hits_by_tf:
        total_hits = 0
        bidirectional_hits = 0
        for hits in hits_by_tf[tf].values():
            total_hits += len(hits)
            bidirectional_hits += _count_bidirectional(hits, min_overlap)
        # total_hits is never zero here: a TF is only registered with a hit.
        fraction = bidirectional_hits / (total_hits + 0.0)
        print(tf + '\t' + str(total_hits) + '\t' + str(bidirectional_hits)
              + '\t' + str(fraction))

# Run only when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    run()

