##################################
#                                #
# Last modified 2017/03/22       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import os
import string
import math
import numpy as np
import matplotlib, copy
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from pylab import *
from matplotlib.patches import Circle, RegularPolygon
from matplotlib.collections import PatchCollection
from matplotlib.cbook import get_sample_data
import random

def _parse_ame(path):
    """Return {motifID: corrected p-value string} read from an AME text report.

    Only non-blank lines after the 'Motif p-values' header are parsed; each is
    expected to contain 'exact test p-value of motif <ID> ...' and
    'Corrected p-value: <value>)'.
    """
    ame = {}
    in_list = False
    with open(path) as handle:
        for line in handle:
            if line.startswith('Motif p-values'):
                in_list = True
                continue
            stripped = line.strip()
            if not in_list or stripped == '':
                continue
            motif_id = stripped.split('exact test p-value of motif ')[1].split(' ')[0]
            pval = stripped.split(')')[0].split('Corrected p-value: ')[1]
            print('%s %s' % (motif_id, pval))
            ame[motif_id] = pval
    return ame


def _parse_meme_motifs(path, wanted, radius):
    """Parse a MEME-format motif file, keeping only motifs whose ID is in *wanted*.

    Returns {motifID: {'mot': [row, ...], 'coverage': {offset: 0.0},
    'instances': 0, 'TF': name}}, where each row is a tuple of
    (letter, probability) pairs and 'coverage' holds one zeroed slot per
    offset in [-radius, motif_length + radius).

    Raises ValueError if a matrix row appears before any ALPHABET line.
    """
    motifs = {}
    alphabet = None
    motif_id = None
    in_motif = False
    with open(path) as handle:
        for line in handle:
            if line.startswith('ALPHABET'):
                alphabet = line.split('=')[1].strip()
                print(alphabet)
                continue
            if line.strip() == '':
                continue
            if line.startswith('MOTIF '):
                fields = line.split()
                motif_id = fields[1]
                # Re-evaluate on every MOTIF line so rows of an unwanted motif
                # are never appended to the previously parsed wanted motif.
                in_motif = motif_id in wanted
                if in_motif:
                    # The alternate name is optional in MEME format; fall back to the ID.
                    tf = fields[2] if len(fields) > 2 else motif_id
                    motifs[motif_id] = {
                        'mot': [],
                        'coverage': {},
                        'instances': 0,
                        'TF': tf.replace(':', '_'),
                    }
                continue
            if line.startswith('URL '):
                in_motif = False
                continue
            if not in_motif or line.startswith('letter-probability'):
                continue
            if alphabet is None:
                raise ValueError('no ALPHABET line before motif matrix in ' + path)
            probs = line.split()
            # Index into the alphabet (not zip) so a row wider than the
            # alphabet still fails loudly instead of being truncated.
            row = tuple((alphabet[i], float(p)) for i, p in enumerate(probs))
            motifs[motif_id]['mot'].append(row)
    for motif in motifs.values():
        if motif['mot']:
            for offset in range(-radius, len(motif['mot']) + radius):
                motif['coverage'][offset] = 0.0
    return motifs


def _parse_regions(path, chr_field, left_field, radius):
    """Read the regions file.

    Returns (coverage, in_region):
      coverage[chrom]  dict of 0-valued slots for every position within
                       *radius* of a region, to be filled from the signal file;
      in_region[chrom] set of positions inside [left, right) of some region.
    The right coordinate is taken from the column after *left_field*.
    """
    coverage = {}
    in_region = {}
    with open(path) as handle:
        for line in handle:
            if line.startswith('#') or line.strip() == '':
                continue
            fields = line.strip().split('\t')
            chrom = fields[chr_field]
            left = int(fields[left_field])
            right = int(fields[left_field + 1])
            chrom_cov = coverage.setdefault(chrom, {})
            chrom_region = in_region.setdefault(chrom, set())
            for pos in range(left - radius, right + radius):
                chrom_cov[pos] = 0
            chrom_region.update(range(left, right))
    return coverage, in_region


def _load_signal(path, coverage):
    """Fill *coverage* in place from a 4-column tab file (chrom, start, end, score).

    Only positions already present in coverage[chrom] are written; every
    position in [start, end) receives *score*.
    """
    with open(path) as handle:
        for line in handle:
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            chrom_cov = coverage.get(fields[0])
            if chrom_cov is None:
                continue
            left = int(fields[1])
            right = int(fields[2])
            score = float(fields[3])
            for pos in range(left, right):
                if pos in chrom_cov:
                    chrom_cov[pos] = score


def _iter_fimo_lines(path):
    """Yield the lines of a FIMO file, decompressing .gz/.bz2 with external tools."""
    import subprocess
    if path.endswith('.bz2'):
        cmd = ['bzip2', '-cd', path]
    elif path.endswith('.gz'):
        cmd = ['gunzip', '-c', path]
    else:
        with open(path) as handle:
            for line in handle:
                yield line
        return
    # An argument list (no shell) keeps paths with spaces or metacharacters intact.
    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, universal_newlines=True)
    try:
        for line in proc.stdout:
            yield line
    finally:
        proc.stdout.close()
        proc.wait()


def _accumulate_fimo(path, motifs, coverage, in_region, radius):
    """Add signal around each in-region FIMO hit into its motif's 'coverage'.

    motifs[ID]['instances'] counts every FIMO line for a kept motif, including
    hits on other chromosomes or outside the regions.
    NOTE(review): that count is later used as the divisor for the "mean"
    profile even though out-of-region hits add no signal; confirm intended.
    """
    for count, line in enumerate(_iter_fimo_lines(path), 1):
        if count % 5000000 == 0:
            print(str(count // 1000000) + 'M lines processed in FIMO file')
        if line.startswith('#'):
            continue
        fields = line.strip().split('\t')
        motif = motifs.get(fields[0])
        if motif is None:
            continue
        motif['instances'] += 1
        chrom = fields[1]
        if chrom not in coverage:
            continue
        # FIMO reports 1-based inclusive coordinates; the regions/signal files
        # are assumed 0-based half-open (BED/bedGraph) - confirm against inputs.
        # The motif therefore occupies 0-based bases [start, end).
        start = int(fields[2]) - 1
        end = int(fields[3])
        chrom_region = in_region[chrom]
        if start not in chrom_region or (end - 1) not in chrom_region:
            continue
        chrom_cov = coverage[chrom]
        profile = motif['coverage']
        offsets = range(-radius, end - start + radius)
        strand = fields[4]
        if strand == '+':
            for offset in offsets:
                profile[offset] += chrom_cov[start + offset]
        elif strand == '-':
            # Read the signal right-to-left so offsets follow the motif's orientation.
            for offset in offsets:
                profile[offset] += chrom_cov[end - 1 - offset]


def _draw_logo(matrix, logoPath, logofile, tempfile, nseq=1000):
    """Render a logo for *matrix* by sampling *nseq* sequences and running *logoPath*.

    Each column contributes int(nseq * p) copies of each letter to a pool that
    is sampled uniformly. The sequences are written as FASTA to *tempfile*,
    which is always removed afterwards. Returns *logofile*. The logo tool is
    expected to write logofile + '.png'.
    """
    import shlex
    import subprocess
    pools = []
    for row in matrix:
        pool = []
        for letter, freq in row:
            pool.extend([letter] * int(nseq * freq))
        pools.append(pool)
    with open(tempfile, 'w') as seqfile:
        for i in range(nseq):
            seqfile.write('>seq' + str(i) + '\n')
            seqfile.write(''.join([random.choice(column) for column in pools]) + '\n')
    # Argument list instead of a shell string: TF names such as "TFAP2A(var.2)"
    # would otherwise be a shell syntax error. shlex lets logoPath still carry
    # an interpreter, e.g. "perl /path/to/seqlogo".
    cmd = shlex.split(logoPath) + ['-f', tempfile, '-F', 'PNG', '-c', '-a', '-Y',
                                   '-n', '-M', '-k', '1', '-o', logofile]
    try:
        status = subprocess.call(cmd)
    finally:
        os.remove(tempfile)
    if status != 0:
        print('warning: logo command exited with status %s' % status)
    return logofile


def _write_mean_coverage(motif, outfilename):
    """Write per-offset coverage divided by motif['instances']; return (X, Y).

    X is the sorted list of offsets and Y the normalised values (0.0 for all
    offsets when the motif has no instances).
    """
    counts = motif['instances']
    X = sorted(motif['coverage'])
    Y = []
    with open(outfilename, 'w') as outfile:
        outfile.write('#Pos\tcoverage' + '\n')
        for pos in X:
            if counts == 0:
                normScore = 0.0
            else:
                normScore = motif['coverage'][pos] / float(counts)
            Y.append(normScore)
            outfile.write(str(pos) + '\t' + str(normScore) + '\n')
    return X, Y


def _plot_profile(X, Y, motif_length, radius, logofile, outfilename):
    """Save a line plot of Y vs X with the logo image placed underneath."""
    fig = plt.figure(figsize=(20, 10))
    try:
        im = plt.imread(logofile + '.png')
        # The profile axes span 0.9 of the figure width, so each plotted
        # position is 0.9/len(Y) wide. AR is the motif's width and ZP the
        # left flank's width in figure units. The 1.67 / 7.33 / 9.5 factors
        # presumably compensate for the logo image's own margins - TODO confirm.
        AR = (motif_length + 0.0) / len(Y) * 0.9
        ZP = (radius + 0.0) / len(Y) * 0.9
        rect2 = (0.05 + ZP - 1.67 * AR / 7.33), 0.01, (AR * 9.5 / 7.33), 0.21
        logo_ax = fig.add_axes(rect2)
        logo_ax.imshow(im, aspect='auto')
        logo_ax.axis('off')

        ax = fig.add_axes((0.05, 0.25, 0.9, 0.7))
        ax.plot(X, Y, '', linewidth=2)
        ax.set_xlim(min(X), max(X))
        ax.set_ylim(min(Y), max(Y))
        xticklabels = [str(int(t)) for t in ax.get_xticks()]
        yticklabels = [str(int(t)) for t in ax.get_yticks()]
        ax.set_xticklabels(xticklabels, size=20, weight='bold')
        ax.set_yticklabels(yticklabels, size=20, weight='bold')
        fig.savefig(outfilename)
    finally:
        # Without this every motif's figure stays alive until the process exits.
        plt.close(fig)


def run():
    """Plot mean signal profiles around instances of AME-enriched motifs.

    Command-line arguments (sys.argv[1:]):
      regions      tab-delimited regions; chromosome in column chrFieldID, left
                   coordinate in column leftFieldID, right in the next column
      chrFieldID   0-based column index of the chromosome in *regions*
      leftFieldID  0-based column index of the left coordinate in *regions*
      radius       number of flanking positions profiled on each side
      JASPAR.meme  MEME-format motif file
      ame.txt      AME text report; only motifs listed there are processed
      fimo.txt     FIMO text output, optionally .gz or .bz2 compressed
      wig          tab-delimited chrom/start/end/score signal file
      weblogo      logo command, called with -f -F PNG -c -a -Y -n -M -k 1 -o
      outprefix    prefix of every output file

    For each processed motif writes <outprefix>.<ID>-<TF>.logo (via the logo
    command), .mean_coverage.txt and .mean_coverage.png.
    """
    if len(sys.argv) < 11:
        print('usage: python %s regions chrFieldID leftFieldID radius JASPAR.meme ame.txt fimo.txt wig weblogo outprefix' % sys.argv[0])
        print('\tNote: the fimo.txt file can be .gz or .bz2')
        sys.exit(1)

    regions = sys.argv[1]
    chrFieldID = int(sys.argv[2])
    leftFieldID = int(sys.argv[3])
    radius = int(sys.argv[4])
    JASPAR = sys.argv[5]
    AME = sys.argv[6]
    FIMO = sys.argv[7]
    wig = sys.argv[8]
    logoPath = sys.argv[9]
    outprefix = sys.argv[10]
    tempfile = outprefix + '.temp'

    print('parsing AME output')
    AMEDict = _parse_ame(AME)
    print('finished parsing AME output')

    print('parsing JASPAR file')
    JASPARDict = _parse_meme_motifs(JASPAR, AMEDict, radius)
    print('finished parsing JASPAR file')

    print('parsing regions')
    CoverageDict, RegionPositions = _parse_regions(regions, chrFieldID, leftFieldID, radius)
    print('finished parsing regions')

    print('parsing wig')
    _load_signal(wig, CoverageDict)
    print('finished parsing wig')

    print('parsing FIMO')
    _accumulate_fimo(FIMO, JASPARDict, CoverageDict, RegionPositions, radius)
    print('finished parsing FIMO')

    for motifID in sorted(JASPARDict):
        motif = JASPARDict[motifID]
        TF = motif['TF']
        if not motif['mot']:
            # No matrix rows: nothing to draw and the plot layout would divide by zero.
            print('skipping motif without matrix rows %s %s' % (motifID, TF))
            continue
        print('printing motif logo %s %s' % (motifID, TF))
        base = outprefix + '.' + motifID + '-' + TF
        logofile = _draw_logo(motif['mot'], logoPath, base + '.logo', tempfile)
        X, Y = _write_mean_coverage(motif, base + '.mean_coverage.txt')
        _plot_profile(X, Y, len(motif['mot']), radius, logofile,
                      base + '.mean_coverage.png')

# Run only when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    run()

