##################################
#                                #
# Last modified 2018/10/08       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import os
import string
import math
import numpy as np
import matplotlib, copy
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from pylab import *
from matplotlib.patches import Circle, RegularPolygon
from matplotlib.collections import PatchCollection

def _read_contacts(tabix, matrix_path, chrom, RL, RR, window):
    """Query a tabix-indexed contact matrix and bin its entries by ``window``.

    Runs ``<tabix> <matrix_path> <chrom>:<RL>-<RR>`` through ``os.popen`` and
    parses every non-comment, non-blank, tab-separated output line. The field
    layout assumed here (inferred from the fields used, confirm against the
    matrix producer) is: chrom1, pos1, chrom2, pos2, score. Lines whose
    first and third fields differ are skipped. Both positions are floored
    to a multiple of ``window`` and the score is added symmetrically.

    NOTE(review): the command is built by plain string concatenation, so
    paths containing spaces or shell metacharacters are not quoted.

    Returns:
        dict of dicts, ``contacts[bin_a][bin_b] -> summed score``, with
        ``contacts[a][b] == contacts[b][a]``.
    """
    cmd = tabix + ' ' + matrix_path + ' ' + chrom + ':' + str(RL) + '-' + str(RR)
    print(cmd)
    contacts = {}
    n_lines = 0
    pipe = os.popen(cmd, "r")
    try:
        for line in pipe:
            # Blank lines would otherwise fail on fields[2]; '#' lines are headers.
            if not line.strip() or line.startswith('#'):
                continue
            n_lines += 1
            if n_lines % 100000 == 0:
                print(n_lines)  # progress indicator for large regions
            fields = line.strip().split('\t')
            if fields[0] != fields[2]:
                continue
            pos1 = int(fields[1])
            pos2 = int(fields[3])
            pos1 -= pos1 % window
            pos2 -= pos2 % window
            score = float(fields[4])
            row1 = contacts.setdefault(pos1, {})
            row1[pos2] = row1.get(pos2, 0) + score
            # NOTE(review): when pos1 == pos2 the next two lines hit the same
            # cell again, so diagonal bins receive 2 * score. Kept as-is;
            # confirm whether this is intended.
            row2 = contacts.setdefault(pos2, {})
            row2[pos1] = row2.get(pos1, 0) + score
    finally:
        status = pipe.close()
    if status:
        sys.stderr.write('warning: command exited with non-zero status (%s): %s\n'
                         % (status, cmd))
    return contacts


def _build_matrix(contacts, first_bin, end_bin, window, minS, maxS):
    """Return a dense matrix for bins ``first_bin`` up to (excluding) ``end_bin``.

    Missing cells count as 0. Every value is clamped to [minS, maxS] and then
    rescaled linearly to [0, 1]. Row i / column j correspond to the i-th / j-th
    bin in ascending order. Requires ``maxS > minS``.
    """
    bins = range(first_bin, end_bin, window)
    span = maxS - minS
    matrix = []
    for i in bins:
        row_scores = contacts.get(i, {})
        matrix.append([(min(max(row_scores.get(j, 0), minS), maxS) - minS) / span
                       for j in bins])
    return matrix


def run():
    """Render a region of a tabix-indexed contact matrix as a heatmap image.

    Command-line arguments (sys.argv[1:]), in order:
        matrix.bgz chr left right window(bp) tabix_path x_pixel_size
        y_pixel_size min max colorscheme width(inches,dpi) outfileprefix
        [-saveps]

    Writes ``<outfileprefix>.<chr>_<left>-<right>.png``. If ``-saveps`` is
    given, it also writes the same name with an ``.eps`` extension. Prints
    usage and exits with status 1 when too few arguments are given. Exits
    with an error message on an invalid window, min/max range or region.
    """
    # Thirteen positional arguments are read (sys.argv[1] .. sys.argv[13]),
    # so argv needs at least 14 entries including the script name.
    if len(sys.argv) < 14:
        print('usage: python %s matrix.bgz chr left right window(bp) tabix_path x_pixel_size y_pixel_size min max colorscheme width(inches,dpi) outfileprefix [-saveps]' % sys.argv[0])
        print('\tNote: the input file is assumed to be in a .bgz, tabix-indexed format')
        print('\tNote: the window can be larger than the window at which the matrix was run, but it has to be a multiple of it; then the sum of the neighboring windows will be taken')
        print('\tNote: http://matplotlib.org/examples/color/colormaps_reference.html')
        sys.exit(1)

    matrix_path = sys.argv[1]
    chrom = sys.argv[2]
    RL = int(sys.argv[3])
    RR = int(sys.argv[4])
    window = int(sys.argv[5])
    tabix = sys.argv[6]
    xps = float(sys.argv[7])
    yps = float(sys.argv[8])
    minS = float(sys.argv[9])
    maxS = float(sys.argv[10])
    cscheme = sys.argv[11]
    (inches, DP) = sys.argv[12].split(',')
    inches = float(inches)
    DP = int(DP)
    outprefix = sys.argv[13]
    doPostScript = '-saveps' in sys.argv

    # Reject inputs that would otherwise fail later with a modulo/division by
    # zero or an invalid range step.
    if window <= 0:
        sys.exit('error: window must be a positive number of bp, got %d' % window)
    if maxS <= minS:
        sys.exit('error: max (%s) must be greater than min (%s)' % (maxS, minS))

    contacts = _read_contacts(tabix, matrix_path, chrom, RL, RR, window)

    RLpos = RL - RL % window
    RRpos = RR - RR % window
    if RRpos <= RLpos:
        sys.exit('error: region %s:%d-%d does not cover any %d bp window'
                 % (chrom, RL, RR, window))

    DataMatrix = _build_matrix(contacts, RLpos, RRpos, window, minS, maxS)
    # pcolor draws row 0 at the bottom of the y-axis; reversing puts the
    # first bin at the top of the image.
    DataMatrix.reverse()

    NRows = len(DataMatrix)
    NColumns = len(DataMatrix[0])
    Height = NRows * yps
    Width = NColumns * xps

    print('%s %s %s' % (NRows, yps, Height))
    print('%s %s %s' % (NColumns, xps, Width))
    print('%s %s %s' % (inches, Height / Width, inches * (Height / Width)))

    fig = plt.figure(figsize=(inches, inches * (Height / Width)), dpi=DP)
    # One axes spanning the whole figure, so the heatmap has no margins.
    ax = fig.add_axes((0, 0, 1, 1))
    ax.pcolor(DataMatrix, vmin=0, vmax=1, cmap=cscheme)

    stem = outprefix + '.' + chrom + '_' + str(RL) + '-' + str(RR)
    fig.savefig(stem + '.png')
    if doPostScript:
        fig.savefig(stem + '.eps', format='eps')
    plt.close(fig)
   
# Only render when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    run()
