##################################
#                                #
# Last modified 2017/03/22       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import os
import string
import math
import numpy as np
import matplotlib, copy
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from pylab import *
from matplotlib.patches import Circle, RegularPolygon
from matplotlib.collections import PatchCollection
from matplotlib.cbook import get_sample_data
import random

def _read_config(config):
    """Return the tab-split fields of every non-comment, non-blank line of *config*."""
    entries = []
    with open(config) as handle:
        for cline in handle:
            if cline.startswith('#') or not cline.strip():
                continue
            entries.append(cline.strip().split('\t'))
    return entries


def _add_region_positions(target, regions, chrFieldID, leftFieldID, rightFieldID, pad):
    """Mark every base of every region listed in *regions* in *target*.

    Each region [left, right) is widened by *pad* on both sides and every
    covered position is stored as ``target[chrom][pos] = 0`` (existing entries
    are reset to 0). Comment ('#') and blank lines are skipped.

    Returns the number of region lines read.
    """
    n_regions = 0
    with open(regions) as handle:
        for line in handle:
            if line.startswith('#') or not line.strip():
                continue
            n_regions += 1
            fields = line.strip().split('\t')
            chrom = fields[chrFieldID]
            left = int(fields[leftFieldID])
            right = int(fields[rightFieldID])
            positions = target.setdefault(chrom, {})
            for i in range(left - pad, right + pad):
                positions[i] = 0
    return n_regions


def _load_wig(wig, coverage):
    """Copy scores from a tab-separated (chrom, start, end, score) track into *coverage*.

    Each interval is expanded with range(start, end). Only positions already
    present in *coverage* are assigned; everything else in the track is ignored.
    """
    with open(wig) as handle:
        for line in handle:
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            chrom_cov = coverage.get(fields[0])
            if chrom_cov is None:
                continue
            left = int(fields[1])
            right = int(fields[2])
            score = float(fields[3])
            for i in range(left, right):
                if i in chrom_cov:
                    chrom_cov[i] = score


def _parse_ame(AME):
    """Return {motifID: corrected p-value string} from the 'Motif p-values' section of *AME*."""
    pvalues = {}
    in_list = False
    with open(AME) as handle:
        for line in handle:
            if line.startswith('Motif p-values'):
                in_list = True
                continue
            if not in_list or line.strip() == '':
                continue
            text = line.strip()
            motifID = text.split('exact test p-value of motif ')[1].split(' ')[0]
            pvalues[motifID] = text.split(')')[0].split('Corrected p-value: ')[1]
    return pvalues


def _parse_jaspar(JASPAR, wanted, radius):
    """Parse a MEME-format motif file, keeping only motifs whose ID is in *wanted*.

    Returns {motifID: {'mot': rows, 'coverage': {offset: 0}, 'instances': 0, 'TF': name}}.
    Each row is a tuple of (letter, probability) pairs. 'coverage' has one slot
    per offset in [-radius, motif_length + radius).
    """
    motifs = {}
    alphabet = 'ACGT'  # fallback only if no ALPHABET line precedes the matrices (assumes DNA)
    current = None  # entry receiving matrix rows; None while skipping a motif
    with open(JASPAR) as handle:
        for line in handle:
            if line.startswith('ALPHABET'):
                alphabet = line.split('=')[1].strip()
                continue
            if line.strip() == '':
                continue
            if line.startswith('MOTIF '):
                parts = line.split()
                motifID = parts[1]
                # Reset on every MOTIF line so rows of unselected motifs are never
                # appended to the previously selected one (even without URL lines).
                current = None
                if motifID in wanted:
                    TF = parts[2] if len(parts) > 2 else motifID
                    current = {'mot': [], 'coverage': {}, 'instances': 0,
                               'TF': TF.replace(':', '_')}
                    motifs[motifID] = current
                continue
            if line.startswith('URL '):
                current = None
                continue
            if current is None or line.startswith('letter-probability'):
                continue
            fields = line.split()
            current['mot'].append(tuple((alphabet[i], float(p)) for i, p in enumerate(fields)))
    for motif in motifs.values():
        motif['coverage'] = {i: 0 for i in range(-radius, len(motif['mot']) + radius)}
    return motifs


def _accumulate_fimo(FIMO, motifs, coverage, region_positions, radius, doVerbose):
    """Add the signal around every in-region FIMO hit of a selected motif to its profile.

    A hit counts only if both its start and stop are positions in
    *region_positions*. The signal over [start - radius, stop + radius) is then
    added to the motif's 'coverage' profile, and its 'instances' counter is
    incremented so the profile can be averaged over contributing hits only.
    '.bz2' and '.gz' inputs are decompressed through a shell pipe.
    """
    if FIMO.endswith('.bz2'):
        cmd = 'bzip2 -cd ' + FIMO
    elif FIMO.endswith('.gz'):
        cmd = 'gunzip -c ' + FIMO
    else:
        cmd = 'cat ' + FIMO
    pipe = os.popen(cmd, 'r')
    try:
        for c, line in enumerate(pipe, 1):
            if doVerbose and c % 10000000 == 0:
                print(str(c // 1000000) + 'M lines processed in FIMO file')
            if line.startswith('#'):
                continue
            # NOTE(review): assumes the classic FIMO column order
            # (motif, sequence, start, stop, strand); newer FIMO releases insert
            # a motif_alt_id column -- confirm against the inputs in use.
            fields = line.strip().split('\t')
            motif = motifs.get(fields[0])
            if motif is None:
                continue
            chrom = fields[1]
            chrom_regions = region_positions.get(chrom)
            if chrom not in coverage or chrom_regions is None:
                continue
            left = int(fields[2])
            right = int(fields[3])
            if left not in chrom_regions or right not in chrom_regions:
                continue
            strand = fields[4]
            if strand not in ('+', '-'):
                continue
            chrom_cov = coverage[chrom]
            profile = motif['coverage']
            for i in range(left - radius, right + radius):
                # '-' strand hits read the window mirrored about the motif
                # (same mapping as before: position left + right - 1 - i).
                # TODO(review): confirm the off-by-one against FIMO's coordinates.
                src = i if strand == '+' else left + right - 1 - i
                # Positions outside the padded regions carry no signal.
                profile[i - left] += chrom_cov.get(src, 0.0)
            motif['instances'] += 1
    finally:
        pipe.close()


def _make_logo(rows, temp_path, logofile, logoPath, Nseq=1000):
    """Render a sequence logo for a PWM and return the motif length.

    For each PWM row, a pool holding int(Nseq * p) copies of every letter is
    built. Nseq sequences are then sampled from those pools, written as FASTA
    to *temp_path*, and passed to the program at *logoPath*. The temporary
    FASTA file is removed afterwards.
    """
    pools = []
    for row in rows:
        pool = []
        for letter, freq in row:
            pool.extend([letter] * int(Nseq * freq))
        pools.append(pool)
    with open(temp_path, 'w') as seqfile:
        for i in range(Nseq):
            sequence = ''.join(random.choice(pool) for pool in pools)
            seqfile.write('>seq' + str(i) + '\n')
            seqfile.write(sequence + '\n')
    cmd = logoPath + ' -f ' + temp_path + ' -F PNG -c -a -Y -n -M -k 1 -o ' + logofile
    os.system(cmd)
    os.remove(temp_path)
    return len(pools)


def _write_mean_coverage(motif, outfilename):
    """Write the per-offset mean signal of *motif* to *outfilename*; return (offsets, means)."""
    counts = motif['instances']
    X = sorted(motif['coverage'])
    Y = []
    with open(outfilename, 'w') as outfile:
        outfile.write('#Pos\tcoverage\n')
        for pos in X:
            if counts > 0:
                normScore = float(motif['coverage'][pos]) / counts
            else:
                normScore = 0.0
            Y.append(normScore)
            outfile.write(str(pos) + '\t' + str(normScore) + '\n')
    return X, Y


def _plot_profile(X, Y, logo_png, motifLength, radius, outfilename):
    """Save a line plot of the profile (X, Y), with the logo image drawn below the motif span."""
    fig = figure(figsize=(20, 10))
    im = plt.imread(logo_png)

    # Fractions of the 0.9-wide main axes covered by the motif (AR) and by the left flank (ZP).
    AR = float(motifLength) / len(Y) * 0.9
    ZP = float(radius) / len(Y) * 0.9
    # NOTE(review): 1.67 / 7.33 / 9.5 look like empirical proportions of the
    # weblogo image (left margin / letter-stack width / total width) used to
    # line the letters up with the motif span -- confirm.
    rect2 = (0.05 + ZP - 1.67 * AR / 7.33), 0.01, (AR * 9.5 / 7.33), 0.21
    newax = fig.add_axes(rect2)
    newax.imshow(im, aspect='auto')
    newax.axis('off')

    ax = fig.add_axes((0.05, 0.25, 0.9, 0.7))
    ax.plot(X, Y, '', linewidth=2)
    ax.set_title('')
    ax.set_xlabel('')
    ax.set_ylabel('')
    xlim = (min(X), max(X))
    ylim = (min(Y), max(Y))
    ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)
    xticks = ax.get_xticks()
    yticks = ax.get_yticks()
    # Pin the auto-chosen ticks before relabelling them (newer matplotlib warns
    # when labels are set on unpinned ticks), then restore the limits, which
    # pinning may have widened.
    ax.set_xticks(xticks)
    ax.set_yticks(yticks)
    ax.set_xticklabels([str(int(t)) for t in xticks], size=20, weight='bold')
    # '%g' keeps fractional mean coverage readable; int() would print 0.5 as '0'.
    ax.set_yticklabels(['%g' % t for t in yticks], size=20, weight='bold')
    ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)

    fig.savefig(outfilename)
    plt.close(fig)  # release the figure; otherwise one leaks per motif


def run():
    """Command-line entry point: aggregate signal around motif hits and plot it.

    Usage:
        python script config radius JASPAR.meme fimo.txt wig weblogo [-verbose]

    Config lines are tab-separated:
        regions, chrFieldID, leftFieldID, rightFieldID, ame.txt, outfile_prefix

    For every config entry, the motifs listed in the AME report are read from
    the MEME file. Each FIMO hit of those motifs lying inside the entry's
    regions contributes the wig signal in a +/- radius window to a profile.
    For each motif the following are written:
        <prefix>.<motif>-<TF>.logo            (from weblogo)
        <prefix>.<motif>-<TF>.mean_coverage.txt
        <prefix>.<motif>-<TF>.mean_coverage.png
    Exits with status 1 and a usage message when fewer than six arguments are given.
    """
    # Six positional arguments are required (argv[1]..argv[6]).
    if len(sys.argv) < 7:
        print('usage: python %s config radius JASPAR.meme fimo.txt wig weblogo [-verbose]' % sys.argv[0])
        print('\tNote: config file format: regions <tab> chrFieldID <tab> leftFieldID <tab> rightFieldID  <tab> ame.txt <tab> outfile_prefix')
        print('\tNote: the fimo.txt file can be .gz or .bz2')
        sys.exit(1)

    config = sys.argv[1]
    radius = int(sys.argv[2])
    JASPAR = sys.argv[3]
    FIMO = sys.argv[4]
    wig = sys.argv[5]
    logoPath = sys.argv[6]

    print('parsing regions')

    doVerbose = '-verbose' in sys.argv

    entries = _read_config(config)

    # Signal lookup for every base of every region file, padded by radius.
    CoverageDict = {}
    total_regions = 0
    for cfields in entries:
        regions = cfields[0]
        chrFieldID = int(cfields[1])
        leftFieldID = int(cfields[2])
        rightFieldID = int(cfields[3])
        if doVerbose:
            print(regions)
        total_regions += _add_region_positions(CoverageDict, regions, chrFieldID,
                                               leftFieldID, rightFieldID, radius)

    TBP = sum(len(positions) for positions in CoverageDict.values())
    print('total regions processed: %d' % total_regions)
    print('total bases considered: %d' % TBP)
    print('finished parsing regions')

    print('parsing wig')
    _load_wig(wig, CoverageDict)
    print('finished parsing wig')

    for Done, cfields in enumerate(entries, 1):
        if Done % 20 == 0:
            print('%d items processed' % Done)
        regions = cfields[0]
        chrFieldID = int(cfields[1])
        leftFieldID = int(cfields[2])
        rightFieldID = int(cfields[3])
        AME = cfields[4]
        outprefix = cfields[5]
        temp_path = outprefix + '.temp'

        # Unpadded positions of this entry's regions: hits must start and end inside them.
        RegionCoverageDict = {}
        _add_region_positions(RegionCoverageDict, regions, chrFieldID,
                              leftFieldID, rightFieldID, 0)

        if doVerbose:
            print('parsing AME output %s' % AME)
        AMEDict = _parse_ame(AME)
        if doVerbose:
            print('finished parsing AME output')
            print(list(AMEDict.keys()))
            print('parsing JASPAR file')

        JASPARDict = _parse_jaspar(JASPAR, AMEDict, radius)
        if doVerbose:
            print('finished parsing JASPAR file')
            print('parsing FIMO')

        _accumulate_fimo(FIMO, JASPARDict, CoverageDict, RegionCoverageDict, radius, doVerbose)
        if doVerbose:
            print('finished parsing FIMO')

        for motifID, motif in JASPARDict.items():
            TF = motif['TF']
            if doVerbose:
                print('printing motif logo %s %s' % (motifID, TF))
            stem = outprefix + '.' + motifID + '-' + TF
            logofile = stem + '.logo'
            motifLength = _make_logo(motif['mot'], temp_path, logofile, logoPath)
            X, Y = _write_mean_coverage(motif, stem + '.mean_coverage.txt')
            _plot_profile(X, Y, logofile + '.png', motifLength, radius,
                          stem + '.mean_coverage.png')

# Script entry point: this call executes the whole pipeline whenever the file
# is run -- and also whenever it is imported, since it is unguarded.
# NOTE(review): consider `if __name__ == '__main__':` so importing the module
# does not trigger the pipeline (left as-is to keep module side effects unchanged).
run()

