##################################
#                                #
# Last modified 2016/10/23       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import math

def _strand_label(strands):
    """Collapse the strand dictionary of one motif position into a label.

    Returns '+-' when both '+' and '-' were recorded at the position;
    otherwise returns the first key obtained by iterating the dictionary.
    """
    if '+' in strands and '-' in strands:
        return '+-'
    return next(iter(strands))


def _strands_conflict(strand, strand2):
    """Return True when two labels differ and neither of them is '+-'."""
    return strand2 != strand and strand2 != '+-' and strand != '+-'


def _offset(position, center, flip):
    """Return position relative to center, negated when flip is true."""
    if flip:
        return -(position - center)
    return position - center


def _read_matches(filename, matches):
    """Load motif matches from a tab-delimited file into matches.

    Every line must provide ID, start, end and strand in its first four
    columns. The result is stored as matches[ID][midpoint][strand] = 1,
    where midpoint is the float average of start and end.
    """
    with open(filename) as infile:
        for line in infile:
            fields = line.strip().split('\t')
            ID = fields[0]
            middle = (int(fields[1]) + int(fields[2])) / 2.
            strand = fields[3]
            matches.setdefault(ID, {}).setdefault(middle, {})[strand] = 1


def _read_narrowpeak_summits(filename, matches):
    """Load peak positions from a narrowPeak-style file into matches.

    The keys already present in matches must look like 'chr:left-right'.
    For each line the peak position is column 2 plus column 10 (presumably
    the narrowPeak summit offset -- confirm against the input format). A
    peak falling inside a region (bounds inclusive) is stored at offset
    peak - left with both strands set, so it is later labelled '+-'.
    """
    # Parse the region coordinates once instead of once per input line.
    regions = []
    for ID in matches:
        parts = ID.split(':')
        coords = parts[1].split('-')
        regions.append((ID, parts[0], int(coords[0]), int(coords[1])))

    with open(filename) as infile:
        for line in infile:
            fields = line.strip().split('\t')
            chrom = fields[0]
            peak = int(fields[9]) + int(fields[1])
            for (ID, IDchr, left, right) in regions:
                if IDchr == chrom and peak >= left and peak <= right:
                    summit = matches[ID].setdefault(peak - left, {})
                    summit['-'] = 1
                    summit['+'] = 1


def _read_sort_order(args, knownIDs):
    """Return the region IDs ordered by descending score for -sortBy.

    The three arguments following '-sortBy' are the score file name, the
    ID field (one column index, or three comma-separated indexes from which
    'chr:left-right' is rebuilt) and the score column index. Lines starting
    with '#' are skipped, and IDs absent from knownIDs are dropped. Exits
    with status 1 when a comma-separated ID specification does not contain
    exactly three fields.
    """
    flag = args.index('-sortBy')
    sortfile = args[flag + 1]
    fieldIDs = args[flag + 2]
    scoreID = int(args[flag + 3])

    if ',' in fieldIDs:
        columns = fieldIDs.split(',')
        if len(columns) != 3:
            print('incorrect number of sorting ID fields specified, exiting')
            sys.exit(1)
        columns = [int(c) for c in columns]
    else:
        columns = [int(fieldIDs)]

    scoresList = []
    with open(sortfile) as infile:
        for line in infile:
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            score = float(fields[scoreID])
            if len(columns) == 3:
                ID = fields[columns[0]] + ':' + fields[columns[1]] + '-' + fields[columns[2]]
            else:
                ID = fields[columns[0]]
            if ID in knownIDs:
                scoresList.append((score, ID))

    # Ascending sort followed by a reversal gives descending (score, ID).
    scoresList.sort()
    scoresList.reverse()
    return [ID for (score, ID) in scoresList]


def _write_row(outfile, ID, number, pos1, pos2, strand1, strand2):
    """Write one tab-separated output record followed by a newline."""
    outline = '\t'.join([ID, str(number), str(pos1), str(pos2), strand1, strand2])
    outfile.write(outline + '\n')


def _write_unpaired(outfile, ID, number, matches1):
    """Write first-set motifs with 'nan' in the second-set columns.

    Positions are reported relative to the first position obtained by
    iterating matches1.
    """
    center = next(iter(matches1))
    for middle in matches1:
        strand = _strand_label(matches1[middle])
        _write_row(outfile, ID, number, middle - center, 'nan', strand, 'nan')


def _write_region(outfile, ID, number, regionLen, matches1, matches2, doSS):
    """Write all output records for a single region.

    matches1 holds the first-set motifs of the region (possibly empty);
    matches2 holds the second-set motifs, or is None when the region has
    none. With doSS set, motif pairs whose strand labels conflict are
    ignored.
    """
    if not matches1:
        # No first-set motif: positions are relative to the region midpoint.
        if matches2 is None:
            _write_row(outfile, ID, number, 'nan', 'nan', 'nan', 'nan')
            return
        center = regionLen // 2
        for middle2 in matches2:
            strand2 = _strand_label(matches2[middle2])
            _write_row(outfile, ID, number, 'nan', middle2 - center, 'nan', strand2)
        return

    if matches2 is None:
        _write_unpaired(outfile, ID, number, matches1)
        return

    # Find the first-set motif with the closest admissible second-set motif.
    # regionLen + 1 acts as the "nothing found yet" sentinel.
    ND = regionLen + 1
    NDm = ''
    for middle in matches1:
        strand = _strand_label(matches1[middle])
        for middle2 in matches2:
            strand2 = _strand_label(matches2[middle2])
            if doSS and _strands_conflict(strand, strand2):
                continue
            distance = math.fabs(middle2 - middle)
            if distance < ND:
                ND = distance
                NDm = middle

    if ND == regionLen + 1:
        _write_unpaired(outfile, ID, number, matches1)
        return

    center = NDm
    # Coordinates are mirrored when the centre motif is on the minus strand.
    flip = _strand_label(matches1[center]) == '-'
    for middle in matches1:
        strand = _strand_label(matches1[middle])
        paired = False
        for middle2 in matches2:
            strand2 = _strand_label(matches2[middle2])
            if doSS and _strands_conflict(strand, strand2):
                continue
            _write_row(outfile, ID, number,
                       _offset(middle, center, flip),
                       _offset(middle2, center, flip),
                       strand, strand2)
            paired = True
        if not paired:
            _write_row(outfile, ID, number,
                       _offset(middle, center, flip), 'nan', strand, 'nan')


def run():
    """Report the positions of two sets of motif matches per region.

    Command line (read from sys.argv):
        regions file, PWMmatches1, PWMmatches2, outfilename,
        [-sameStrandOnly] [-file1NarrowPeak]
        [-sortBy filename IDfield(s) scorefield]

    The regions file lists one region per line as ID and length (tab
    separated). For each region, the first-set motif that has the closest
    second-set motif becomes the centre, and every motif pair of the region
    is written relative to it; coordinates are negated when the centre motif
    is on the minus strand. Regions lacking matches in one or both sets are
    written with 'nan' placeholders. Regions are numbered from 1 in output
    order: sorted by ID, or by descending score with -sortBy.

    Prints the usage text and exits with status 1 when fewer than the four
    required positional arguments are supplied.
    """
    # Four positional arguments are needed; sys.argv[4] is the output file.
    if len(sys.argv) < 5:
        print('usage: python %s PWMmatches1.chrom.sizes PWMmatches1 PWMmatches2 outfilename [-sameStrandOnly] [-file1NarrowPeak] [-sortBy filename IDfield(s) scorefield]' % sys.argv[0])
        print('\tNote: the script will center motifs on the first set of matches, and on the motif instance that has a closest match in the second set')
        print('\tNote2: It will reverse the coordinates is the center motif is on the minus strand')
        print('\tNote3: sortBy option: if only a single ID field it is assumed that it contaisn IDs that are the same is in the PWMmatches files; if 3 fields are specified, the coordinates will be reconstituted from those; sorting will be in descending order')
        sys.exit(1)

    doSS = '-sameStrandOnly' in sys.argv

    doF1NP = '-file1NarrowPeak' in sys.argv
    if doF1NP:
        print('will treat first file as narrowPeak')

    RegionLenDict = {}
    PWMmatch1Dict = {}
    PWMmatch2Dict = {}

    # Every listed region gets an (initially empty) first-set entry, so
    # regions without any motif are still reported.
    with open(sys.argv[1]) as infile:
        for line in infile:
            fields = line.strip().split('\t')
            ID = fields[0]
            PWMmatch1Dict[ID] = {}
            RegionLenDict[ID] = int(fields[1])

    doSort = '-sortBy' in sys.argv
    if doSort:
        print('will sort regions')
        IDs = _read_sort_order(sys.argv, RegionLenDict)

    if doF1NP:
        _read_narrowpeak_summits(sys.argv[2], PWMmatch1Dict)
    else:
        _read_matches(sys.argv[2], PWMmatch1Dict)

    _read_matches(sys.argv[3], PWMmatch2Dict)

    if not doSort:
        IDs = sorted(RegionLenDict)

    with open(sys.argv[4], 'w') as outfile:
        outfile.write('#Note: the coordinate for region with minus-strand center motifs are reversed!!!' + '\n')
        outfile.write('#ID\tNumber\tMotif1Pos\tMotif2Pos\tStrand1\tStrand2' + '\n')
        for i, ID in enumerate(IDs, 1):
            _write_region(outfile, ID, i, RegionLenDict[ID],
                          PWMmatch1Dict[ID], PWMmatch2Dict.get(ID), doSS)

# Script entry point: run() is called unconditionally, so it executes
# whenever this file is run or imported (there is no __main__ guard).
run()

