##################################
#                                #
# Last modified 2019/12/02       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import os
import string
import math
import numpy as np
import matplotlib, copy
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from pylab import *
from matplotlib.patches import Circle, RegularPolygon
from matplotlib.collections import PatchCollection

def _load_contacts(path):
    """Read a contact dump into a symmetric nested dict of summed scores.

    Each data line holds three tab-separated fields: ``pos1``, ``pos2``
    (integers) and ``score`` (float).  Lines starting with ``#`` and blank
    lines are skipped.  Every score is added to both ``[pos1][pos2]`` and
    ``[pos2][pos1]``; when ``pos1 == pos2`` the same cell therefore receives
    the score twice, exactly as in the original implementation.

    The file is opened directly instead of being piped through ``cat`` so
    that paths containing spaces or shell metacharacters are handled safely
    and the handle is always closed.
    """
    contacts = {}
    n_records = 0
    with open(path) as handle:
        for line in handle:
            if line.startswith('#') or not line.strip():
                continue
            n_records += 1
            if n_records % 100000 == 0:
                # Progress report for large dumps.
                print('%s %d' % (path, n_records))
            fields = line.strip().split('\t')
            pos1 = int(fields[0])
            pos2 = int(fields[1])
            score = float(fields[2])
            row1 = contacts.setdefault(pos1, {})
            row2 = contacts.setdefault(pos2, {})
            row1[pos2] = row1.get(pos2, 0) + score
            row2[pos1] = row2.get(pos1, 0) + score
    return contacts


def _difference_matrix(contacts1, contacts2, min_pos, max_pos, window,
                       min_score, max_score):
    """Return the rescaled (dump2 - dump1) matrix as a list of rows.

    Bins are ``range(min_pos, max_pos, window)`` on both axes; a bin missing
    from a dump counts as 0.  Each difference is clamped to
    ``[min_score, max_score]`` and mapped linearly onto ``[0, 1]``.

    NOTE(review): ``range`` excludes ``max_pos`` itself, so the bin starting
    at the largest position is not drawn.  This matches the original code;
    confirm whether that is intended.
    """
    no_row = {}
    span = max_score - min_score
    bins = range(min_pos, max_pos, window)
    matrix = []
    for i in bins:
        row1 = contacts1.get(i, no_row)
        row2 = contacts2.get(i, no_row)
        row = []
        for j in bins:
            v = row2.get(j, 0) - row1.get(j, 0)
            v = min(max(v, min_score), max_score)
            row.append((v - min_score) / span)
        matrix.append(row)
    return matrix


def run():
    """Plot the difference between two contact maps as a heatmap.

    Command-line arguments (``sys.argv``), in order: dump1, dump2, window,
    x_pixel_size, y_pixel_size, min, max, colorscheme, ``width,dpi`` and
    outfileprefix, optionally followed by ``-saveps``, ``-minPos pos`` and
    ``-maxPos pos``.

    Writes ``<outfileprefix>.png`` and, with ``-saveps``,
    ``<outfileprefix>.eps``.  Exits with status 1 when arguments are missing
    or when there is nothing to plot.
    """
    # Ten positional arguments are required, i.e. argv[1] .. argv[10].
    if len(sys.argv) < 11:
        print('usage: python %s dump1 dump2 window x_pixel_size y_pixel_size min max colorscheme width(inches,dpi) outfileprefix [-saveps] [-minPos pos] [-maxPos pos]' % sys.argv[0])
        print('\tInput format:')
        print('\t0       40000   9.722908')
        sys.exit(1)

    dump1 = sys.argv[1]
    dump2 = sys.argv[2]
    window = int(sys.argv[3])
    xps = float(sys.argv[4])
    yps = float(sys.argv[5])
    minS = float(sys.argv[6])
    maxS = float(sys.argv[7])
    cscheme = sys.argv[8]
    (inches, DP) = sys.argv[9].split(',')
    inches = float(inches)
    DP = int(DP)
    outprefix = sys.argv[10]

    if maxS <= minS:
        # The rescaling below divides by (maxS - minS).
        sys.exit('error: max (%s) must be greater than min (%s)' % (maxS, minS))

    doPostScript = '-saveps' in sys.argv

    # The documented option names are -minPos / -maxPos; test for those and
    # remember that an explicit bound was supplied.
    doMin = '-minPos' in sys.argv
    if doMin:
        HardMin = int(sys.argv[sys.argv.index('-minPos') + 1])

    doMax = '-maxPos' in sys.argv
    if doMax:
        HardMax = int(sys.argv[sys.argv.index('-maxPos') + 1])

    DataMatrix1 = _load_contacts(dump1)
    DataMatrix2 = _load_contacts(dump2)

    # Bounds that were not given explicitly are taken from the positions
    # present in either dump.
    positions = set(DataMatrix1) | set(DataMatrix2)
    if not positions and not (doMin and doMax):
        sys.exit('error: no data found in %s or %s' % (dump1, dump2))

    minPos = HardMin if doMin else min(positions)
    maxPos = HardMax if doMax else max(positions)

    DataMatrix = _difference_matrix(DataMatrix1, DataMatrix2,
                                    minPos, maxPos, window, minS, maxS)
    if not DataMatrix:
        sys.exit('error: empty position range %d-%d' % (minPos, maxPos))

    # Flip the row order so the first position bin becomes the last row
    # handed to pcolor.
    DataMatrix.reverse()

    NRows = len(DataMatrix)
    NColumns = len(DataMatrix[0])

    Height = NRows * yps
    Width = NColumns * xps

    print('%s %s %s' % (NRows, yps, Height))
    print('%s %s %s' % (NColumns, xps, Width))
    print('%s %s %s' % (inches, Height / Width, inches * (Height / Width)))

    # Axes rectangle spanning the whole figure (left, bottom, width, height).
    rect = 0, 0, 1, 1
    fig = plt.figure(figsize=(inches, inches * (Height / Width)), dpi=DP)
    # NOTE(review): this subplot is immediately superseded by the axes added
    # below; it is kept only to leave the rendered figure untouched --
    # confirm whether it can be removed.
    fig.add_subplot(1, 1, 1, aspect='equal')
    ax = fig.add_axes(rect)

    ax.pcolor(DataMatrix, vmin=0, vmax=1, cmap=cscheme)

    fig.savefig(outprefix + '.png')

    if doPostScript:
        fig.savefig(outprefix + '.eps', format='eps')
   
# Run the pipeline only when executed as a script, so that importing this
# module does not parse sys.argv or write any files.
if __name__ == '__main__':
    run()
