##################################
#                                #
# Last modified 2019/03/29       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import os
import string
import math
import numpy as np
import matplotlib, copy
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from pylab import *
from matplotlib.patches import Circle, RegularPolygon
from matplotlib.collections import PatchCollection

def _decompress_command(path):
    """Return the argv list of the command that streams *path* to stdout.

    The decompressor is chosen from the file extension; anything that is not
    recognised as bzip2 or (b)gzip is treated as plain text.
    """
    if path.endswith('.bz2'):
        return ['bzip2', '-cd', path]
    if path.endswith(('.gz', '.bgz')):
        return ['zcat', path]
    return ['cat', path]


def _read_binned_scores(path, chr1, chr2, window, region1, region2):
    """Sum the scores of *path* into ``{row_bin: {column_bin: score}}``.

    Each non-comment line is expected to be tab-separated with the columns
    ``chrA  posA  chrB  posB  score``.  Positions are snapped down to a
    multiple of *window*; row bins are chr1 coordinates and column bins are
    chr2 coordinates, whichever order the two chromosomes appear in on a line.

    region1 and region2 are ``(left, right)`` tuples used as a coarse
    pre-filter so that bins far outside the plotted area are not stored.
    """
    # Imported locally so that the function does not depend on the file's
    # top-level import list.
    import subprocess

    (RL1, RR1) = region1
    (RL2, RR2) = region2

    # Mirroring is only meaningful for an intra-chromosomal (symmetric)
    # matrix.  For two different chromosomes the row and column coordinates
    # live in different spaces, and mirroring would let a chr2 coordinate be
    # mistaken for a chr1 bin.
    mirror = (chr1 == chr2)

    cmd = _decompress_command(path)
    print(' '.join(cmd))

    scores = {}
    # An argv list (no shell) keeps paths with spaces or shell
    # metacharacters intact.
    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                            universal_newlines=True)
    try:
        nlines = 0
        for line in proc.stdout:
            if line.startswith('#'):
                continue
            nlines += 1
            if nlines % 1000000 == 0:
                print('%d M lines processed' % (nlines // 1000000))
            fields = line.strip().split('\t')
            if fields[0] == chr1 and fields[2] == chr2:
                pos1 = int(fields[1])
                pos2 = int(fields[3])
            elif fields[0] == chr2 and fields[2] == chr1:
                pos2 = int(fields[1])
                pos1 = int(fields[3])
            else:
                continue
            pos1 -= pos1 % window
            pos2 -= pos2 % window
            if not (RL1 <= pos1 + window <= RR1
                    and RL2 <= pos2 + window <= RR2):
                continue
            score = float(fields[4])
            row = scores.setdefault(pos1, {})
            row[pos2] = row.get(pos2, 0) + score
            if mirror:
                # NOTE: when pos1 == pos2 this adds the score a second time
                # to the same cell; that long-standing behaviour for diagonal
                # bins is kept unchanged.
                col = scores.setdefault(pos2, {})
                col[pos1] = col.get(pos1, 0) + score
    finally:
        proc.stdout.close()
        status = proc.wait()

    if status != 0:
        sys.stderr.write('warning: "%s" exited with status %d; '
                         'the matrix may be incomplete\n'
                         % (' '.join(cmd), status))
    return scores


def _normalized_matrix(scores, row_bins, col_bins, minS, maxS):
    """Return the list-of-rows matrix of scores clipped to [minS, maxS] and
    rescaled to [0, 1]; bins without data count as a score of 0."""
    span = maxS - minS
    matrix = []
    for i in row_bins:
        row_scores = scores.get(i, {})
        row = []
        for j in col_bins:
            v = row_scores.get(j, 0)
            if v < minS:
                v = minS
            if v > maxS:
                v = maxS
            row.append((v - minS) / span)
        matrix.append(row)
    return matrix


def run():
    """Command-line entry point: plot a region of a contact matrix.

    Reads the 15 required positional arguments from ``sys.argv`` (see the
    usage message), sums the scores of the input matrix into bins of
    ``window`` bp, clips/rescales them to [0, 1] and writes a ``pcolor``
    heatmap to ``<outprefix>.<chr1>_<l1>-<r1>.<chr2>_<l2>-<r2>.png``.  With
    ``-saveps`` anywhere on the command line an ``.eps`` copy is written too.

    Exits with status 1 on a usage error or on arguments that cannot produce
    a plot (non-positive window or pixel size, max <= min, or a region that
    does not span a full window).
    """
    # 15 positional arguments are read below, so argv needs 16 entries.
    if len(sys.argv) < 16:
        print('usage: python %s matrix.bgz chr1 left1 right1 chr2 left2 right2 window(bp) x_pixel_size y_pixel_size min max colorscheme width(inches,dpi) outfileprefix [-saveps]' % sys.argv[0])
        print('\tNote: the input file is assumed to be in a .bgz, tabix-indexed format')
        print('\tNote: the window can be larger than the window at which the matrix was run, but it has to be a multiple of it; then the sum of the neighboring windows will be taken')
        print('\tNote: http://matplotlib.org/examples/color/colormaps_reference.html')
        sys.exit(1)

    matrix_path = sys.argv[1]
    chr1 = sys.argv[2]
    RL1 = int(sys.argv[3])
    RR1 = int(sys.argv[4])
    chr2 = sys.argv[5]
    RL2 = int(sys.argv[6])
    RR2 = int(sys.argv[7])
    window = int(sys.argv[8])
    xps = float(sys.argv[9])
    yps = float(sys.argv[10])
    minS = float(sys.argv[11])
    maxS = float(sys.argv[12])
    cscheme = sys.argv[13]
    (inches, DP) = sys.argv[14].split(',')
    inches = float(inches)
    DP = int(DP)
    outprefix = sys.argv[15]

    doPostScript = '-saveps' in sys.argv

    # Validate everything that is later used as a divisor or that decides
    # the matrix shape *before* the (potentially very large) file is read.
    if window <= 0:
        sys.exit('error: window must be a positive integer')
    if xps <= 0 or yps <= 0:
        sys.exit('error: x_pixel_size and y_pixel_size must be positive')
    if maxS <= minS:
        sys.exit('error: max must be greater than min')

    # Snap the region boundaries down to window multiples; these define the
    # row (chr1) and column (chr2) bins of the plotted matrix.
    RLpos1 = RL1 - RL1 % window
    RRpos1 = RR1 - RR1 % window
    RLpos2 = RL2 - RL2 % window
    RRpos2 = RR2 - RR2 % window
    if RLpos1 >= RRpos1 or RLpos2 >= RRpos2:
        sys.exit('error: each region must span at least one full window')

    print('.............')

    scores = _read_binned_scores(matrix_path, chr1, chr2, window,
                                 (RL1, RR1), (RL2, RR2))

    DataMatrix = _normalized_matrix(scores,
                                    range(RLpos1, RRpos1, window),
                                    range(RLpos2, RRpos2, window),
                                    minS, maxS)

    # pcolor draws row 0 at the bottom; reversing puts the first bin of
    # region 1 at the top of the image.
    DataMatrix.reverse()

    print(DataMatrix)

    NRows = len(DataMatrix)
    NColumns = len(DataMatrix[0])

    Height = NRows * yps
    Width = NColumns * xps

    print('%s %s %s' % (NRows, yps, Height))
    print('%s %s %s' % (NColumns, xps, Width))
    print('%s %s %s' % (inches, Height / Width, inches * (Height / Width)))

    # A single axes covering the whole figure: the image has no margins,
    # ticks or labels around the heatmap.
    rect = 0, 0, 1, 1
    fig = plt.figure(figsize=(inches, inches * (Height / Width)), dpi=DP)
    ax = fig.add_axes(rect)

    ax.pcolor(DataMatrix, vmin=0, vmax=1, cmap=cscheme)

    outbase = (outprefix + '.' + chr1 + '_' + str(RL1) + '-' + str(RR1)
               + '.' + chr2 + '_' + str(RL2) + '-' + str(RR2))

    fig.savefig(outbase + '.png')

    if doPostScript:
        fig.savefig(outbase + '.eps', format='eps')
   
# Only run the command-line tool when executed as a script, so that importing
# this module does not parse sys.argv or exit the interpreter.
if __name__ == '__main__':
    run()
