##################################
#                                #
# Last modified 2018/04/05       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import os
import string
import math
import numpy as np
import matplotlib, copy
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from pylab import *
from matplotlib.patches import Circle, RegularPolygon
from matplotlib.collections import PatchCollection

def _option_value(flag):
    """Return the integer that follows *flag* in sys.argv.

    Exits with an error message if the value is missing or is not an integer.
    Only call this when *flag* is known to be present in sys.argv.
    """
    position = sys.argv.index(flag) + 1
    try:
        return int(sys.argv[position])
    except (IndexError, ValueError):
        sys.exit('%s requires an integer argument' % flag)


def _scale_row(row, minS, maxS):
    """Collapse each cell of *row* to its sum, clip to [minS, maxS], map onto [0, 1].

    np.sum is used so a cell may be either a scalar or an array.
    Requires maxS > minS (validated by the caller).
    """
    sums = np.array([np.sum(cell) for cell in row], dtype=float)
    return (np.clip(sums, minS, maxS) - minS) / (maxS - minS)


def _window_row(values, window):
    """Average *values* over consecutive, non-overlapping blocks of *window* entries.

    Every complete block is kept (including the last one); a trailing
    partial block shorter than *window* is dropped.
    """
    n_blocks = len(values) // window
    return values[:n_blocks * window].reshape(n_blocks, window).mean(axis=1)


def _average_row(values, span):
    """Return the mean of every run of *span* consecutive entries of *values*.

    Produces len(values) - span + 1 entries, or none when the row is
    shorter than *span*.
    """
    if len(values) < span:
        # np.convolve would swap its operands here, so handle it explicitly.
        return np.zeros(0)
    return np.convolve(values, np.ones(span) / span, mode='valid')


def run():
    """Render a .npy signal matrix as a heatmap image.

    Command line:
        datafile.npy x_pixel_size y_pixel_size min max colorscheme
        inches,dpi outfile [-average bp] [-window bp]

    Each cell of every input row is reduced to its sum.  The sums are clipped
    to [min, max] and rescaled to [0, 1].  Each row is then optionally binned
    into non-overlapping windows (-window, takes precedence) or smoothed with
    a moving average (-average).  The result is drawn with pcolor so that the
    first input row ends up at the top of the image.  The figure is `inches`
    wide at `dpi`, with its height chosen to keep the requested x/y pixel
    aspect ratio.
    """
    # sys.argv[8] (outfile) is the last mandatory argument, so 9 entries are needed.
    if len(sys.argv) < 9:
        print('usage: python %s datafile.npy x_pixel_size y_pixel_size min max colorscheme width(inches,dpi) outfile [-average bp] [-window bp]' % sys.argv[0])
        print('\tNote: enter width as a comma-separated tuple of the inches and dpi; the height will be rescaled accordingly')
        print('\tNote: http://matplotlib.org/examples/color/colormaps_reference.html')
        sys.exit(1)

    input_path = sys.argv[1]
    xps = float(sys.argv[2])
    yps = float(sys.argv[3])
    minS = float(sys.argv[4])
    maxS = float(sys.argv[5])
    cscheme = sys.argv[6]
    try:
        inches, DP = sys.argv[7].split(',')
        inches = float(inches)
        DP = int(DP)
    except ValueError:
        sys.exit('width must be given as inches,dpi (e.g. 10,300)')
    outfilename = sys.argv[8]

    # Guard the divisions below (rescaling and the height/width ratio).
    if maxS <= minS:
        sys.exit('max (%s) must be greater than min (%s)' % (maxS, minS))
    if xps <= 0 or yps <= 0:
        sys.exit('pixel sizes must be positive')

    if '-every' in sys.argv:
        print('warning: -every is not implemented and will be ignored')

    doAverage = '-average' in sys.argv
    averageRadius = 0
    if doAverage:
        averageRadius = int(_option_value('-average') / 2.0)
        # A zero radius would average over an empty span (0/0 -> NaN).
        if averageRadius < 1:
            sys.exit('-average must be at least 2 bp')
        print('will average signal over %s bp' % (2 * averageRadius))

    doWindow = '-window' in sys.argv
    window = 1
    if doWindow:
        window = _option_value('-window')
        if window < 1:
            sys.exit('-window must be at least 1 bp')
        print('will split into windows of size %s bp' % window)

    Data = np.load(input_path)
    if len(Data) == 0:
        sys.exit('%s contains no rows' % input_path)
    print('%s %s' % (len(Data), len(Data[0])))

    rows = []
    for row in Data:
        values = _scale_row(row, minS, maxS)
        if doWindow:
            values = _window_row(values, window)
        elif doAverage:
            values = _average_row(values, 2 * averageRadius)
        rows.append(values)

    if len(rows[0]) == 0:
        sys.exit('no columns left to plot; is the window/average wider than the data?')

    # pcolor places row 0 at y=0 (the bottom); flip so the first input row is on top.
    DataMatrix = np.array(rows[::-1])
    NRows = len(DataMatrix)
    NColumns = len(DataMatrix[0])
    print('%s %s' % (NRows, NColumns))

    Height = NRows * yps
    Width = NColumns * xps

    print('%s %s %s' % (NRows, yps, Height))
    print('%s %s %s' % (NColumns, xps, Width))
    print('%s %s %s' % (inches, Height / Width, inches * (Height / Width)))

    fig = plt.figure(figsize=(inches, inches * (Height / Width)), dpi=DP)
    # One axes covering the entire figure, so the heatmap has no margins.
    ax = fig.add_axes([0, 0, 1, 1])
    ax.pcolor(DataMatrix, vmin=0, vmax=1, cmap=cscheme)

    fig.savefig(outfilename)
    plt.close(fig)
   
# Run only when executed as a script, not when imported.
if __name__ == '__main__':
    run()
