##################################
#                                #
# Last modified 2017/08/23       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set
import matplotlib
matplotlib.use('Agg')
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.path as mpath
import matplotlib.lines as mlines
import matplotlib.patches as mpatches
from matplotlib.collections import PatchCollection

def label(xy, text):
    """Write *text* centred on ``xy[0]``, shifted up from ``xy[1]`` by 0.05.

    xy   -- indexable pair; element 0 is the x position, element 1 the base y
    text -- string passed to ``plt.text``
    """
    x_pos = xy[0]
    y_pos = 0.05 + xy[1]
    plt.text(x_pos, y_pos, text, ha="center", family='sans-serif', size=14)

# Fill color used for the protein backbone when the config file has no
# "protein" entry (the usage text describes that entry as optional).
_DEFAULT_PROTEIN_COLOR = '#cccccc'


def _read_color_config(path):
    """Read the config file and return a ``{name: color}`` dict.

    Lines starting with ``#`` and blank lines are skipped.  Every other line
    is split on tabs; the first field is the key and the second the color.
    """
    colors = {}
    with open(path) as handle:
        for line in handle:
            if line.startswith('#') or line.strip() == '':
                continue
            fields = line.strip().split('\t')
            colors[fields[0]] = fields[1]
    return colors


def _read_protein_ids(path, field_id):
    """Return the protein IDs in column *field_id*, in file order, no repeats.

    Runs of spaces are collapsed and turned into a tab before the line is
    split on tabs (as announced in the usage text).  Lines starting with
    ``#`` and blank lines are skipped.
    """
    ordered = []
    seen = set()
    with open(path) as handle:
        for line in handle:
            if line.startswith('#'):
                continue
            newline = line.strip()
            if not newline:
                # A blank line carries no ID; skipping it avoids registering
                # an empty protein name (or an IndexError when field_id > 0).
                continue
            while '  ' in newline:
                newline = newline.replace('  ', ' ')
            fields = newline.replace(' ', '\t').split('\t')
            protein = fields[field_id]
            if protein not in seen:
                seen.add(protein)
                ordered.append(protein)
    return ordered


def _read_pfam_hits(path, protein_ids, colors, max_evalue=None):
    """Collect the domain hits of the requested proteins from the PFAM table.

    The table is tab-separated; columns 0, 1, 2, 6, 10 and 11 are read as
    protein ID, protein length, domain name, E-value, domain start and
    domain end.  Lines starting with ``#`` are skipped, as are hits whose
    E-value exceeds *max_evalue* (when it is not None).

    Returns a tuple ``(hits, legend, max_length)``:
      hits       -- {protein: {'length': int,
                               'domains': [(domain, start, end, color)]}}
                    for every protein with at least one retained hit
      legend     -- unique (domain, color) pairs in first-seen order
      max_length -- largest protein length among the retained hits
    """
    wanted = set(protein_ids)
    hits = {}
    legend = []
    legend_seen = set()
    max_length = 0
    with open(path) as handle:
        for line in handle:
            if line.startswith('#'):
                continue
            fields = line.split('\t')
            protein = fields[0]
            if protein not in wanted:
                continue
            evalue = float(fields[6])
            if max_evalue is not None and evalue > max_evalue:
                continue
            domain = fields[2]
            length = int(fields[1])
            if length > max_length:
                max_length = length
            left = int(fields[10])
            right = int(fields[11])
            # Domains missing from the config are drawn in black.
            color = colors.get(domain, 'k')
            entry = hits.setdefault(protein, {'domains': []})
            entry['length'] = length
            entry['domains'].append((domain, left, right, color))
            if (domain, color) not in legend_seen:
                legend_seen.add((domain, color))
                legend.append((domain, color))
    return hits, legend, max_length


def _save_and_close(fig, basename):
    """Remove margins and axes from the current figure, save PNG + EPS, close."""
    plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
    plt.axis('equal')
    plt.axis('off')
    plt.savefig(basename + '.png', dpi=200)
    plt.savefig(basename + '.eps', format='eps')
    # Release the figure so it does not stay registered with pyplot.
    plt.close(fig)


def run():
    """Command-line entry point: draw protein domain diagrams and a legend.

    Positional arguments (read from ``sys.argv``):
      1  config            tab-separated ``name <tab> color`` file
      2  PFAM table        tab-separated domain hits
      3  list_of_proteins  file holding the protein IDs to plot
      4  fieldID           0-based column of the ID in list_of_proteins
      5  bar_height        height per protein row, in inches
      6  figure_width      figure width, in inches
      7  outputfileprefix  prefix of the generated files
    Optional: ``-maxE value`` drops hits with a larger E-value.

    Writes ``<prefix>.proteins.png/.eps`` and ``<prefix>.domains.png/.eps``.
    Proteins are drawn in the order they appear in list_of_proteins; only
    proteins with at least one retained PFAM hit are drawn.

    Exits with status 1 after printing a message when too few arguments are
    given, when ``-maxE`` has no value, when nothing can be plotted, or when
    a domain extends past its protein.

    The print calls are written so that they parse under both Python 2 and
    Python 3 and produce the same text as Python 2 print statements.
    """
    # Seven positional arguments plus the script name are required; the
    # last one (outputfileprefix) lives at sys.argv[7].
    if len(sys.argv) < 8:
        print('usage: python %s config PFAM27-A.tab list_of_proteins fieldID bar_height figure_wdith outputfileprefix [-maxE value]' % sys.argv[0])
        print('\tconfig file format')
        print('\tdomain <tab> color')
        print('\tThere should be a "protein" entry if you want something different than the default color')
        print('\tThe proteins will be plotted in the order listed')
        print('\tbar height and figure width in inches')
        print('\tNote: the script will collapse and then replace with tabs all spaces in the list_of_proteins file')
        print('\tNamed colors: https://matplotlib.org/examples/color/named_colors.html')
        sys.exit(1)

    config = sys.argv[1]
    PFAM = sys.argv[2]
    listofproteins = sys.argv[3]
    fieldID = int(sys.argv[4])
    BH = float(sys.argv[5])
    FW = float(sys.argv[6])
    outfileprefix = sys.argv[7]

    maxE = None
    if '-maxE' in sys.argv:
        value_index = sys.argv.index('-maxE') + 1
        if value_index >= len(sys.argv):
            print('-maxE requires a value')
            sys.exit(1)
        maxE = float(sys.argv[value_index])
        print('max E-value set to %s' % maxE)

    ColorDict = _read_color_config(config)
    listed = _read_protein_ids(listofproteins, fieldID)
    ProteinDict, Dlist, maxLength = _read_pfam_hits(PFAM, listed, ColorDict, maxE)

    # Keep the order of list_of_proteins, restricted to proteins with hits.
    Plist = [P for P in listed if P in ProteinDict]

    if not Plist:
        # Without any hit the legend figure would have zero height.
        print('no PFAM domains found for the listed proteins, nothing to plot')
        sys.exit(1)

    print('plotting proteins')

    fig = plt.figure()
    # Row layout is based on every listed protein, including those without
    # retained hits.
    fig.set_size_inches(FW, BH * len(listed))
    ax = fig.add_subplot(111)

    step = 1. / (len(listed) + 1)
    print(step)
    pos = 1 - 0.5 * step

    protein_color = ColorDict.get('protein', _DEFAULT_PROTEIN_COLOR)

    for protein in Plist:
        pend = ProteinDict[protein]['length']
        # Proteins are scaled so the longest one spans 0.9 of the x range.
        length = 0.9 * (pend / (maxLength + 0.0))
        print('%s %s %s %s %s' % (protein, pos, length, step / 1.5, protein_color))
        # Backbone bar; facecolor/edgecolor are given separately because
        # "color" would set both and conflict with an explicit edge color.
        rect = matplotlib.patches.Rectangle((0.05, pos), length, 0.66 * step,
                                            facecolor=protein_color,
                                            edgecolor='black',
                                            linestyle='solid')
        ax.add_patch(rect)
        for (domain, dstart, dend, c) in ProteinDict[protein]['domains']:
            if max(dstart, dend) > maxLength:
                print('domain outside of maximal protein length range, exiting')
                sys.exit(1)
            if max(dstart, dend) > pend:
                print('domain outside of protein length range, exiting')
                sys.exit(1)
            S = 0.9 * (dstart / (maxLength + 0.0))
            E = 0.9 * (dend / (maxLength + 0.0))
            rect = matplotlib.patches.Rectangle((0.05 + S, pos), E - S,
                                                0.66 * step, color=c)
            ax.add_patch(rect)
        # Unfilled black outline drawn last so it sits on top of the domains.
        rect = matplotlib.patches.Rectangle((0.05, pos), length, 0.66 * step,
                                            fill=False, edgecolor='black')
        ax.add_patch(rect)
        # For IDs of the form a|b|..., only the second component is shown.
        if '|' in protein:
            ptext = protein.split('|')[1]
        else:
            ptext = protein
        plt.text(0.05, pos + (2.1 / 3.) * step, ptext, fontsize=BH * 12)
        pos -= step

    _save_and_close(fig, outfileprefix + '.proteins')

    print('plotting domains')

    fig = plt.figure()
    fig.set_size_inches(FW, BH * len(Dlist))
    ax = fig.add_subplot(111)

    step = 1. / (len(Dlist) + 1)
    pos = 1 - 0.5 * step
    for (domain, c) in Dlist:
        rect = matplotlib.patches.Rectangle((0.1, pos), 0.5, 0.66 * step,
                                            facecolor=c, edgecolor='black',
                                            linestyle='solid')
        ax.add_patch(rect)
        rect = matplotlib.patches.Rectangle((0.1, pos), 0.5, 0.66 * step,
                                            fill=False, edgecolor='black')
        ax.add_patch(rect)
        plt.text(0.625, pos + 0.3 * step, domain, fontsize=BH * 12)
        pos -= step

    _save_and_close(fig, outfileprefix + '.domains')
    
# Script entry point: run() is called unconditionally when this file is
# executed (there is no __main__ guard, so importing the file also runs it).
run()
