##################################
#                                #
# Last modified 2021/03/10       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import os
import string
import math
import gzip
import numpy as np
import matplotlib, copy
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import shapefile as shp
import pandas as pd
import geopandas as gpd
from pylab import *
# from pylab import *
# from matplotlib.patches import Circle, RegularPolygon
# from matplotlib.collections import PatchCollection

def _open_text(path):
    """Open *path* for line iteration; ``.gz`` files are decompressed on the fly."""
    if path.endswith('.gz'):
        return gzip.open(path)
    return open(path)


def _read_table(path, id_field, data_field):
    """Read an (ID, value) pair from every data line of a tab-separated file.

    Lines starting with ``#`` and blank lines are skipped. A progress counter
    is printed every 100000 lines.

    Returns:
        Two parallel lists ``(ids, values)``: ``ids`` holds the text of column
        *id_field*, ``values`` the float conversion of column *data_field*.
    """
    ids = []
    values = []
    with _open_text(path) as handle:
        for lineno, line in enumerate(handle, 1):
            if lineno % 100000 == 0:
                print(lineno)
            if not isinstance(line, str):
                # gzip yields bytes under Python 3; decode so the text
                # operations below behave the same as for plain files.
                line = line.decode('utf-8')
            if line.startswith('#') or not line.strip():
                continue
            # Only drop the line terminator: strip() would also remove empty
            # leading/trailing columns and shift the field indices.
            fields = line.rstrip('\r\n').split('\t')
            ids.append(fields[id_field])
            values.append(float(fields[data_field]))
    return ids, values


def _id_key(value):
    """Normalise an ID so text IDs from the data file can match shapefile records.

    Text is stripped of surrounding whitespace; non-text record values (for
    example numeric DBF fields) are converted with ``str()``.
    """
    if isinstance(value, (str, type(u''))):
        return value.strip()
    return str(value)


def _part_coords(shape):
    """Yield ``(xs, ys)`` coordinate lists for each part (ring) of a pyshp shape.

    ``shape.parts`` holds the start index of every part within
    ``shape.points``; a shape without parts is treated as a single part.
    """
    points = getattr(shape, 'points', None) or []
    starts = list(getattr(shape, 'parts', None) or [0])
    ends = starts[1:] + [len(points)]
    for start, end in zip(starts, ends):
        ring = points[start:end]
        yield [p[0] for p in ring], [p[1] for p in ring]


def _shape_patch(shape, **kwargs):
    """Build one PathPatch covering all parts of *shape*, or None if it is empty.

    All parts are combined into a single compound path rather than filled one
    ring at a time. This lets interior rings render as holes when they are
    wound opposite to their outer ring, which is the shapefile convention.
    *kwargs* are passed to ``PathPatch`` (e.g. ``facecolor``).
    """
    from matplotlib.path import Path
    from matplotlib.patches import PathPatch

    vertices = []
    codes = []
    for xs, ys in _part_coords(shape):
        if not xs:
            continue
        vertices.extend(zip(xs, ys))
        # Each part starts a new sub-path so rings are not joined together.
        codes.append(Path.MOVETO)
        codes.extend([Path.LINETO] * (len(xs) - 1))
    if not vertices:
        return None
    return PathPatch(Path(vertices, codes), **kwargs)


def _quantile_colors(values, cmap_name, nbins=6):
    """Colour each value by its quantile bin, sampling colours from a colormap.

    Args:
        values: sequence of floats.
        cmap_name: matplotlib colormap name, e.g. ``'Reds'``.
        nbins: requested number of quantile bins. Fewer bins are used when
            repeated values make some bin edges coincide.

    Returns:
        ``(colors, bins)``. ``colors[k]`` is an RGBA tuple for ``values[k]``,
        or ``None`` when no bin could be assigned (e.g. NaN input). ``bins``
        holds the bin edges.
    """
    if len(values) == 0:
        return [], []
    codes, bins = pd.qcut(values, nbins, labels=False, retbins=True,
                          duplicates='drop')
    cmap = plt.get_cmap(cmap_name)
    nused = max(len(bins) - 1, 1)
    # Sample at (k + 1) / n so the lightest end of the colormap (0.0) is
    # never used for a bin.
    palette = [cmap((k + 1.0) / nused) for k in range(nused)]
    colors = [None if pd.isnull(code) else palette[int(code)] for code in codes]
    return colors, bins


def run():
    """Command-line entry point: draw a choropleth map and save it.

    Usage::

        python <script> datafile IDfieldID dataFieldID figsize(x,y) shapefile SFfieldID outfile [-color c] [-saveps]

    * ``datafile``: tab-separated text (optionally ``.gz``). Column
      ``IDfieldID`` holds region IDs and column ``dataFieldID`` holds
      numeric values.
    * ``figsize``: figure width and height, comma-separated, e.g. ``8,6``.
    * ``shapefile`` / ``SFfieldID``: the shapefile path and the index of the
      record field whose value matches the data-file IDs.
    * ``-color c``: matplotlib colormap name (default ``Reds``).
    * ``-saveps``: additionally write ``outfile + '.eps'``.

    Each region listed in the data file is filled with a colour for its
    quantile bin (six bins requested). All shape outlines are drawn in black.
    """
    if len(sys.argv) < 8:
        print('usage: python %s datafile IDfieldID dataFieldID figsize(x,y) shapefile SFfieldID outfile [-color c] [-saveps]' % sys.argv[0])
        sys.exit(1)

    datafile = sys.argv[1]
    IDfieldID = int(sys.argv[2])
    datafieldID = int(sys.argv[3])
    figsizeX, figsizeY = [float(v) for v in sys.argv[4].split(',')[:2]]
    shapefile = sys.argv[5]
    SFfieldID = int(sys.argv[6])
    outfilename = sys.argv[7]

    doPostScript = '-saveps' in sys.argv

    color = 'Reds'
    if '-color' in sys.argv:
        color = sys.argv[sys.argv.index('-color') + 1]

    IDs, values = _read_table(datafile, IDfieldID, datafieldID)

    sf = shp.Reader(shapefile)
    shapes = sf.shapes()
    # Read the records once; calling sf.records() inside the loop re-reads
    # the whole table for every shape.
    records = sf.records()
    SFDict = {}
    for index, record in enumerate(records):
        SFDict[_id_key(record[SFfieldID])] = index

    colors, bins = _quantile_colors(values, color)
    for k in range(len(bins) - 1):
        print('%d: %g => %g' % (k + 1, bins[k], bins[k + 1]))

    fig = plt.figure(figsize=(figsizeX, figsizeY))
    ax = fig.add_axes((0.10, 0.10, 0.8, 0.8))

    for ID, rgba in zip(IDs, colors):
        if rgba is None:
            continue
        index = SFDict.get(_id_key(ID))
        if index is None:
            sys.stderr.write('warning: ID %s not found in shapefile\n' % ID)
            continue
        patch = _shape_patch(shapes[index], facecolor=rgba, edgecolor='none')
        if patch is not None:
            ax.add_patch(patch)

    # Draw outlines part by part so separate rings of one shape are not
    # joined by stray connecting lines.
    for shape in shapes:
        for xs, ys in _part_coords(shape):
            ax.plot(xs, ys, 'k')

    fig.savefig(outfilename)
    if doPostScript:
        fig.savefig(outfilename + '.eps', format='eps')
    plt.close(fig)
   
# Only run the CLI when executed as a script, so importing this module does
# not parse sys.argv or call sys.exit().
if __name__ == '__main__':
    run()
