##################################
#                                #
# Last modified 2021/03/18       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import os
import string
import math
import gzip
import numpy as np
import matplotlib, copy
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import shapefile as shp
import pandas as pd
import geopandas as gpd
from pylab import *

def _flag_value(argv, flag, cast, default=None):
    """Return the argument following *flag* in *argv*, converted by *cast*.

    Returns *default* when *flag* is absent. Exits with an error message
    when *flag* is the last argument and therefore has no value.
    """
    if flag not in argv:
        return default
    pos = argv.index(flag) + 1
    if pos >= len(argv):
        sys.exit('error: %s requires a value' % flag)
    return cast(argv[pos])


def _read_header_and_values(path, id_col, data_col):
    """Scan the CSV at *path* for column names and numeric data values.

    Lines starting with '#' are treated as header lines: the names found at
    positions *id_col* and *data_col* are recorded (the last such line wins).
    Every other non-blank line contributes ``float(fields[data_col])``.

    Returns a ``(region_name, data_name, values)`` tuple; both names are
    ``None`` when no '#' header line was found.
    """
    region_name = None
    data_name = None
    values = []
    with open(path) as handle:
        for line in handle:
            fields = line.strip().split(',')
            if fields == ['']:
                # Blank line: nothing to parse.
                continue
            if line.startswith('#'):
                region_name = fields[id_col]
                data_name = fields[data_col]
                continue
            values.append(float(fields[data_col]))
    return region_name, data_name, values


def run():
    """Draw a choropleth map from a CSV data file and a shapefile.

    Positional command-line arguments (read from ``sys.argv``):
        1. CSV data file whose header line starts with '#'
        2. zero-based index of the region-ID column in the CSV
        3. zero-based index of the data column in the CSV
        4. figure size as ``x,y``
        5. shapefile path
        6. name of the shapefile field to join on
        7. output image file name

    Optional flags: ``-min float`` / ``-max float`` (color range; default is
    the min/max of the data column), ``-color c`` (colormap name, default
    'Reds'), ``-doBar`` (draw a colorbar), ``-saveps`` (also save
    ``<outfile>.eps``).

    Shapefile regions are joined to the CSV rows on the ID column; missing
    values produced by the join are filled with 0 before plotting.
    """
    # Seven positional arguments plus the script name are required.
    if len(sys.argv) < 8:
        print('usage: python %s datafile.cvs IDfieldID dataFieldID figsize(x,y) shapefile SFfieldIDName outfile [-min float] [-doBar] [-max float] [-color c] [-saveps]' % sys.argv[0])
        sys.exit(1)

    csv_path = sys.argv[1]
    IDfieldID = int(sys.argv[2])
    datafieldID = int(sys.argv[3])
    size_parts = sys.argv[4].split(',')
    figsize = (float(size_parts[0]), float(size_parts[1]))
    shape_path = sys.argv[5]
    SFfieldID = sys.argv[6]
    outfilename = sys.argv[7]

    doPostScript = '-saveps' in sys.argv
    doBar = '-doBar' in sys.argv

    regionID, dataID, values = _read_header_and_values(
        csv_path, IDfieldID, datafieldID)
    if regionID is None:
        sys.exit("error: %s has no header line starting with '#'" % csv_path)

    data = pd.read_csv(csv_path)
    regions = gpd.read_file(shape_path)

    merged = regions.set_index(SFfieldID).join(data.set_index(regionID))
    merged = merged.reset_index()
    merged = merged.fillna(0)

    vmin = _flag_value(sys.argv, '-min', float)
    vmax = _flag_value(sys.argv, '-max', float)
    if (vmin is None or vmax is None) and not values:
        sys.exit('error: no data rows in %s; give -min and -max explicitly'
                 % csv_path)
    if vmin is None:
        vmin = min(values)
    if vmax is None:
        vmax = max(values)

    color = _flag_value(sys.argv, '-color', str, 'Reds')

    print('vmin, vmax %s %s' % (vmin, vmax))
    print('color: %s' % color)

    fig, ax = plt.subplots(1, figsize=figsize)
    ax.axis('off')

    if doBar:
        sm = plt.cm.ScalarMappable(cmap=color, norm=plt.Normalize(vmin=vmin, vmax=vmax))
        # The mappable carries no data of its own; an empty array keeps
        # older matplotlib versions from rejecting it in colorbar().
        sm._A = []
        # Name the axes explicitly: newer matplotlib cannot infer which
        # axes to take space from for a mappable not attached to any axes.
        cbar = fig.colorbar(sm, ax=ax)
        cbar.ax.tick_params(labelsize=20)

    merged.plot(dataID, cmap=color, norm=plt.Normalize(vmin=vmin, vmax=vmax), linewidth=0.8, ax=ax, edgecolor='0.8', figsize=figsize)

    fig.savefig(outfilename)

    if doPostScript:
        fig.savefig(outfilename + '.eps', format='eps')

# Run only when executed as a script, so that importing this module does not
# parse sys.argv or produce any output files.
if __name__ == '__main__':
    run()
