##################################
#                                #
# Last modified 2018/11/30       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import numpy as np
import os
from scipy.cluster.hierarchy import dendrogram, linkage, fcluster
from sets import Set
import matplotlib, copy
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from pylab import *
from matplotlib.patches import Circle, RegularPolygon
from matplotlib.collections import PatchCollection
import random

def run():
    """Scatter-plot pairwise NMI values for one genomic window.

    Command line (read from ``sys.argv``)::

        SingleMoleculeCorrelation-empirical chr left right colorscheme
        minScore maxScore edge_color outfile_prefix
        [-dotsize size] [-resize factor]

    The input file may be plain text or ``.bz2`` / ``.gz`` / ``.zip``
    compressed.  It is read as tab-separated rows.  Rows whose first column
    differs from ``chr`` are ignored.  Rows whose peak 1 ends before ``left``
    or whose peak 2 starts after ``right`` are also ignored.

    Each kept row becomes one dot at (midpoint of peak 1, midpoint of
    peak 2).  The dot is coloured by column 15 (``NMI``) on the
    ``colorscheme`` colormap, with the colour scale set to
    [minScore, maxScore].  Its size is ``|NMI| * dotsize``.  If
    ``edge_color`` is anything other than ``'none'``, it is used as the
    marker edge colour.

    Writes ``<outfile_prefix>.scatter.png`` and
    ``<outfile_prefix>.scatter.eps``.  Prints a usage message and exits
    with status 1 when fewer than nine positional arguments are supplied.
    """
    import subprocess  # local import keeps this edit self-contained

    # The script name plus nine positional arguments are required, because
    # sys.argv[9] (outfile_prefix) is read below.
    if len(sys.argv) < 10:
        print('usage: python %s SingleMoleculeCorrelation-empirical chr left right colorscheme minScore maxScore edge_color outfile_prefix [-dotsize size] [-resize factory_pixel_size]' % sys.argv[0])
        print('\tassumed input format:')
        print('\t\t#chr	peak1_left	peak1_right	peak1_open	peak1_closed	peak1_fraction	peak2_left	peak2_right	peak2_open	peak2_closed	peak2_fraction	Fisher_test_p_val	Empirical_p-val	Max_upper_empirical_p-val	Max_lower_empirical_p-val	NMI	strand1	strand2	TSS_array_N	distance')
        sys.exit(1)

    infile = sys.argv[1]
    chrom = sys.argv[2]
    left_pos = int(sys.argv[3])
    right_pos = int(sys.argv[4])
    cmap_name = sys.argv[5]
    score_min = float(sys.argv[6])
    score_max = float(sys.argv[7])
    edge_color = sys.argv[8]
    outprefix = sys.argv[9]

    dotsize = 1
    if '-dotsize' in sys.argv:
        dotsize = float(sys.argv[sys.argv.index('-dotsize') + 1])
        print('will scale dots by a factor of %s' % dotsize)

    resize = 1.0
    if '-resize' in sys.argv:
        resize = float(sys.argv[sys.argv.index('-resize') + 1])
        print('will resize a factor of %s' % resize)

    # Decompress via external tools, passing the path as a separate argv
    # element rather than splicing it into a shell command string.
    if infile.endswith('.bz2'):
        cmd = ['bzip2', '-cd', infile]
    elif infile.endswith('.gz'):
        cmd = ['gunzip', '-c', infile]
    elif infile.endswith('.zip'):
        cmd = ['unzip', '-p', infile]
    else:
        cmd = ['cat', infile]

    xs = []       # peak 1 midpoints
    ys = []       # peak 2 midpoints
    colors = []   # NMI values (colour channel)
    sizes = []    # |NMI| * dotsize (marker area)

    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                            universal_newlines=True)
    try:
        for raw in proc.stdout:
            line = raw.strip()
            # Skip blank lines rather than stopping at them, so rows after
            # a blank line are still read.
            if not line or line.startswith('#'):
                continue
            fields = line.split('\t')
            if fields[0] != chrom:
                continue
            peak1_left = int(fields[1])
            peak1_right = int(fields[2])
            peak2_left = int(fields[6])
            peak2_right = int(fields[7])
            if peak1_right < left_pos or peak2_left > right_pos:
                continue
            nmi = float(fields[15])
            xs.append((peak1_left + peak1_right) / 2.0)
            ys.append((peak2_left + peak2_right) / 2.0)
            colors.append(nmi)
            sizes.append(abs(nmi) * dotsize)
    finally:
        proc.stdout.close()
        proc.wait()

    fig = plt.figure(figsize=(20 * resize, 20 * resize))
    # Create a single axes.  The previous add_subplot() axes was overwritten
    # immediately but still rendered as an empty frame behind the plot.
    ax = fig.add_axes([0.10, 0.10, 0.8, 0.8], aspect='equal')
    scatter_kwargs = dict(marker='o', s=sizes, c=colors,
                          vmin=score_min, vmax=score_max, cmap=cmap_name)
    if edge_color != 'none':
        scatter_kwargs['edgecolor'] = edge_color
    ax.scatter(xs, ys, **scatter_kwargs)
    # Both axes span the same genomic window.
    ax.set_xlim(left_pos, right_pos)
    ax.set_ylim(left_pos, right_pos)

    fig.savefig(outprefix + '.scatter.png')
    fig.savefig(outprefix + '.scatter.eps', format='eps')
    plt.close(fig)

# Run the CLI only when executed as a script, not when imported.
if __name__ == '__main__':
    run()

