##################################
#                                #
# Last modified 2019/09/30       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import os
import numpy as np
from sets import Set

def _count_merged_intervals(intervals):
    """Return how many disjoint intervals remain after merging overlaps.

    intervals is a list of (left, right) tuples.  Two intervals are merged
    when the left edge of the later one is strictly smaller than the right
    edge of the interval merged so far.  An empty list yields 0.
    """
    if not intervals:
        return 0
    ordered = sorted(intervals)
    merged_count = 1
    current_right = ordered[0][1]
    # Every hit after the first is visited, including the last one.
    for left, right in ordered[1:]:
        if left < current_right:
            # Keep the furthest right edge: a hit fully contained in the
            # current interval must not shrink it.
            current_right = max(current_right, right)
        else:
            merged_count += 1
            current_right = right
    return merged_count


def _fetch_motif_intervals(tabix, motifsfile, region,
                           motifFieldID, leftFieldID, rightFieldID):
    """Query the motif file with tabix and group the hits by motif.

    Runs "<tabix> <motifsfile> <region>" through the shell and parses each
    tab-separated output line.  Returns a dict mapping the motif field to a
    list of (left, right) integer tuples.
    """
    # NOTE: the arguments are interpolated into a shell command unquoted,
    # exactly as given on the command line.
    cmd = tabix + ' ' + motifsfile + ' ' + region
    hits_by_motif = {}
    pipe = os.popen(cmd, "r")
    try:
        for raw in pipe:
            hit = raw.strip()
            if not hit:
                continue
            hit_fields = hit.split('\t')
            interval = (int(hit_fields[leftFieldID]),
                        int(hit_fields[rightFieldID]))
            hits_by_motif.setdefault(hit_fields[motifFieldID], []).append(interval)
    finally:
        # Close the pipe so one child process per region is not left open.
        pipe.close()
    return hits_by_motif


def run():
    """Summarise merged motif counts per peak for each labelled peak set.

    Command-line arguments (sys.argv[1:]):
        peak_list     tab-separated file: label, peak file, chrFieldID
        motifs_file   tabix-indexed motif file
        motifFieldID  column of the motif name in the tabix output
        chrFieldID    column of the chromosome in the tabix output
        leftFieldID   column of the motif left coordinate
        rightFieldID  column of the motif right coordinate
        tabix_path    command used to invoke tabix
        mean|median   statistic reported per motif and label
        outfile       path of the table to write

    For every peak, motif hits returned by tabix are grouped by motif and
    overlapping hits are merged; the number of merged hits is recorded.
    Peaks without hits for a motif count as 0 when the motif was seen in
    that label at least once.  The output has a '#Set:' header line, a
    '#NumRegions:' line and one row per motif with one column per label.

    Exits with status 1 when too few arguments are given or the statistic
    is neither 'mean' nor 'median'.
    """
    # Nine positional arguments are read below, so argv needs 10 entries.
    if len(sys.argv) < 10:
        print('usage: python %s peak_list motifs_file motifFielfID chrFieldID leftFieldID rightFieldID tabix_path mean|median outfile' % sys.argv[0])
        print('\tNote: peak_list format: label <tab> file <tab> chrFieldID')
        print('\tNote: a tabix-index motif file is assumed')
        sys.exit(1)

    peakslist = sys.argv[1]
    motifsfile = sys.argv[2]
    motifFieldID = int(sys.argv[3])
    # Parsed for command-line validation only; the counts do not use it.
    chrFieldID = int(sys.argv[4])
    leftFieldID = int(sys.argv[5])
    rightFieldID = int(sys.argv[6])
    tabix = sys.argv[7]
    MM = sys.argv[8]
    outfilename = sys.argv[9]

    # Fail early instead of silently writing rows without value columns.
    if MM not in ('mean', 'median'):
        print('unknown statistic %s, expected mean or median' % MM)
        sys.exit(1)
    if MM == 'mean':
        summarize = np.mean
    else:
        summarize = np.median

    MotifCountDict = {}   # label -> motif -> list of merged counts per peak
    NumPeaksDict = {}     # label -> number of peaks read
    seen_motifs = set()

    with open(peakslist) as peaksfiles:
        for peakline in peaksfiles:
            if not peakline.strip():
                continue
            peakfields = peakline.strip().split('\t')
            label = peakfields[0]
            print(label)
            cFieldID = int(peakfields[2])
            MotifCountDict[label] = {}
            NumPeaks = 0
            with open(peakfields[1]) as linelist:
                for line in linelist:
                    if line.startswith('#') or not line.strip():
                        continue
                    fields = line.strip().split('\t')
                    chrom = fields[cFieldID]
                    start = int(fields[cFieldID + 1])
                    end = int(fields[cFieldID + 2])
                    NumPeaks += 1
                    if NumPeaks % 100 == 0:
                        print('%s %s %s %s' % (NumPeaks, chrom, start, end))
                    region = chrom + ':' + str(start) + '-' + str(end)
                    hits_by_motif = _fetch_motif_intervals(
                        tabix, motifsfile, region,
                        motifFieldID, leftFieldID, rightFieldID)
                    for motif, intervals in hits_by_motif.items():
                        seen_motifs.add(motif)
                        MotifCountDict[label].setdefault(motif, []).append(
                            _count_merged_intervals(intervals))
            NumPeaksDict[label] = NumPeaks

    labels = sorted(NumPeaksDict)

    with open(outfilename, 'w') as outfile:
        outfile.write('\t'.join(['#Set:'] + labels) + '\n')
        outfile.write('\t'.join(['#NumRegions:'] +
                                [str(NumPeaksDict[label]) for label in labels]) + '\n')

        for motif in sorted(seen_motifs):
            columns = [motif]
            for label in labels:
                if motif in MotifCountDict[label]:
                    counts = MotifCountDict[label][motif]
                    # Peaks with no hit for this motif contribute a zero.
                    padded = counts + [0] * (NumPeaksDict[label] - len(counts))
                    columns.append(str(summarize(padded)))
                else:
                    columns.append(str(0))
            outfile.write('\t'.join(columns) + '\n')

# Script entry point: run() is called unconditionally when this file is
# executed (or imported) and reads its arguments from sys.argv.
run()

