##################################
#                                #
# Last modified 2018/12/10       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import matplotlib, copy
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from pylab import *
from matplotlib.patches import Circle, RegularPolygon
from matplotlib.collections import PatchCollection
import numpy as np
from scipy.stats import fisher_exact
from scipy.stats import beta
from scipy.stats import binom
from scipy.stats import norm
import random
import os
import math
from sets import Set
from sklearn.metrics import normalized_mutual_info_score as NMIS
from scipy.cluster.hierarchy import dendrogram, linkage, fcluster


def _stream_command(path):
    """Return the shell command that writes file *path* to stdout.

    The decompressor is chosen from the file extension (.bz2, .gz/.bgz,
    .zip); any other file is streamed with plain ``cat``.
    """
    if path.endswith('.bz2'):
        return 'bzip2 -cd ' + path
    if path.endswith('.gz') or path.endswith('.bgz'):
        return 'zcat ' + path
    if path.endswith('.zip'):
        return 'unzip -p ' + path
    return 'cat ' + path


def _iter_lines(cmd):
    """Yield the whitespace-stripped stdout lines of shell command *cmd*.

    Iteration ends at end of output or at the first line that is empty
    after stripping, so a blank line terminates the input early.  The pipe
    is always closed once iteration finishes.
    """
    pipe = os.popen(cmd, 'r')
    try:
        while True:
            line = pipe.readline().strip()
            if line == '':
                break
            yield line
    finally:
        pipe.close()


def _call_states(RD, positions, R, alph, bet, PSS):
    """Return one 0/1 call per entry of *positions* for a single read.

    RD maps genomic position -> score for the read.  For every position the
    scores found in [pos - R, pos + R) update a Beta(alph, bet) prior, each
    score p contributing int(PSS * p) pseudo-counts to the first parameter
    and the remainder of PSS to the second.  The call is 1 when the mean of
    the resulting beta distribution is below 0.5 and 0 otherwise.
    """
    states = []
    for pos in positions:
        A, B = alph, bet
        for i in range(pos - R, pos + R):
            if i in RD:
                Z = int(PSS * RD[i])
                A += Z
                B += PSS - Z
        if beta.mean(A, B) < 0.5:
            states.append(1)
        else:
            states.append(0)
    return states


def _linkage_order(rows):
    """Return the indices of *rows* in Ward-linkage order.

    The rows are clustered with scipy's linkage() (optimal leaf ordering
    on), the tree is cut with fcluster() at distance 0, and indices are
    emitted by increasing flat-cluster label.  With fewer than two rows
    there is nothing to cluster (linkage() would raise), so the indices are
    returned unchanged.
    """
    if len(rows) < 2:
        return list(range(len(rows)))
    tree = linkage(rows, method='ward', metric='euclidean', optimal_ordering=True)
    flat_labels = fcluster(tree, 0, criterion='distance')
    members = {}
    for index, flat_label in enumerate(flat_labels):
        members.setdefault(flat_label, []).append(index)
    return [index for flat_label in sorted(members) for index in members[flat_label]]


def _bin_row(scores, window, high_when_above):
    """Collapse *scores* into consecutive bins of *window* values.

    Zero entries are treated as missing.  A bin with no non-zero entry (or
    whose mean is exactly 0) becomes 0.0.  Otherwise the bin becomes 1.0
    when its mean is > 0.5 (high_when_above true) or < 0.5 (high_when_above
    false), and 0.01 in every other case.
    """
    binned = []
    for start in range(0, len(scores), window):
        covered = [v for v in scores[start:start + window] if v != 0]
        if covered:
            mean = np.mean(covered)
        else:
            mean = 0.0
        if high_when_above:
            is_high = mean > 0.5
        else:
            is_high = mean < 0.5
        if mean == 0:
            binned.append(0.0)
        elif is_high:
            binned.append(1.0)
        else:
            binned.append(0.01)
    return binned


def _save_scatter(matrix, span, window, resize, SPcs, SPmin, SPmax, SPedge, outbase):
    """Draw the binned *matrix* as a scatter plot and save it as PNG and EPS.

    Each non-zero cell becomes one marker at (column, row), coloured by its
    value with colormap SPcs scaled to [SPmin, SPmax]; SPedge is the marker
    edge colour, or the string 'none' to leave matplotlib's default.  *span*
    is the region length in bp and is only used, together with *window* and
    the number of rows, to derive the figure's aspect ratio.  Files are
    written to outbase + '.png' and outbase + '.eps'.
    """
    cells = [(col, row, value)
             for row, scores in enumerate(matrix)
             for col, value in enumerate(scores)]
    all_x = [cell[0] for cell in cells]
    all_y = [cell[1] for cell in cells]
    shown = [cell for cell in cells if abs(cell[2]) > 0.0]

    aspectRatio = (span / (window + 0.0)) / len(matrix)
    fig = plt.figure(figsize=(20 * resize, 20 * resize / aspectRatio))
    try:
        # NOTE: this extra, empty subplot is kept so that the rendered
        # output stays the same as before; the data go on the axes below.
        fig.add_subplot(1, 1, 1, aspect='equal')
        ax = fig.add_axes((0.10, 0.10, 0.8, 0.8))
        style = dict(marker='o', c=[cell[2] for cell in shown],
                     vmin=SPmin, vmax=SPmax, cmap=SPcs)
        if SPedge != 'none':
            style['edgecolor'] = SPedge
        ax.scatter([cell[0] for cell in shown], [cell[1] for cell in shown], **style)
        ax.set_xlim(min(all_x) - 2 * window, max(all_x) + 2 * window)
        ax.set_ylim(min(all_y) - 1, max(all_y) + 1)
        fig.savefig(outbase + '.png')
        fig.savefig(outbase + '.eps', format='eps')
    finally:
        # Release the figure; one pair of figures is created per region.
        plt.close(fig)


def run():
    """Command-line entry point: per-read state calls around peak positions.

    Positional arguments (sys.argv[1..12]):
        reads file (queried through tabix), regions file, region chromosome /
        left / right / strand field indices, peaks file, peak chromosome /
        position field indices, radius R, tabix executable, output prefix.
    Options:
        -unstranded        treat every region as '+' strand
        -label fieldID     take the region label from this region-file field
        -cluster window colorscheme minScore maxScore color|none resize
                           also cluster the reads and write scatter plots

    For every region whose chromosome has peaks, a text file
    <prefix>.<label>_for.txt or _rev.txt is written: a '#'-prefixed header
    listing the peak positions inside [left, right) (reversed for '-'
    regions), then one line per retained read with a 0/1 call per position.
    A read is retained when it spans every listed position +/- R and covers
    as many scored positions of the region as the best-covered such read.

    Reads-file lines are tab separated; field 1 / 2 are the read's left /
    right coordinates, field 6 / 7 are comma-separated positions and their
    scores (Tombo 1.3 probabilities, according to the usage note).

    Exits with status 1 after printing usage when arguments are missing.
    """
    # sys.argv[12] is read below, so 13 entries are the minimum.
    if len(sys.argv) < 13:
        print('usage: python %s methylation_reads_all.tsv regions.bed chrFieldID leftField rightField strandFieldID peaks chrFieldID posFieldID radius tabix_location outfile [-unstranded] [-label fieldID] [-cluster window colorscheme minScore maxScore color|none resize]' % sys.argv[0])
        print('\tNote: the script assumes Tombo 1.3 probabilities, a tabix indexed reads file, and uses a beta distribution prior of (10,10) by default')
        sys.exit(1)

    readsFile = sys.argv[1]
    regions = sys.argv[2]
    regionchrFieldID = int(sys.argv[3])
    regionleftFieldID = int(sys.argv[4])
    regionrightFieldID = int(sys.argv[5])
    regionstrandFieldID = int(sys.argv[6])
    peaks = sys.argv[7]
    chrFieldID = int(sys.argv[8])
    posFieldID = int(sys.argv[9])
    R = int(sys.argv[10])
    tabix = sys.argv[11]
    outprefix = sys.argv[12]

    doLabel = '-label' in sys.argv
    labelFieldID = None
    if doLabel:
        labelFieldID = int(sys.argv[sys.argv.index('-label') + 1])

    doCluster = '-cluster' in sys.argv
    window = SPcs = SPmin = SPmax = SPedge = resize = None
    if doCluster:
        at = sys.argv.index('-cluster')
        window = int(sys.argv[at + 1])
        SPcs = sys.argv[at + 2]
        SPmin = float(sys.argv[at + 3])
        SPmax = float(sys.argv[at + 4])
        SPedge = sys.argv[at + 5]
        resize = float(sys.argv[at + 6])

    doUnstranded = '-unstranded' in sys.argv

    # Beta prior parameters and the pseudo-count weight given to each score.
    alph = 10
    bet = 10
    PSS = 100

    # chromosome -> {peak position: 1}
    PeakDict = {}
    for peakline in _iter_lines(_stream_command(peaks)):
        if peakline.startswith('#'):
            continue
        fields = peakline.split('\t')
        PeakDict.setdefault(fields[chrFieldID], {})[int(fields[posFieldID])] = 1

    RN = 0  # running count of read lines returned by tabix (progress only)
    for regionline in _iter_lines(_stream_command(regions)):
        if regionline.startswith('#'):
            continue
        fields = regionline.split('\t')
        chrom = fields[regionchrFieldID]
        RL = int(fields[regionleftFieldID])
        RR = int(fields[regionrightFieldID])
        if doUnstranded:
            strand = '+'
        else:
            strand = fields[regionstrandFieldID]
        if chrom not in PeakDict:
            continue
        if doLabel:
            label = fields[labelFieldID]
        else:
            label = chrom + '_' + str(RL) + '-' + str(RR)
        print(label)
        if strand == '+':
            outbase = outprefix + '.' + label + '_for'
        elif strand == '-':
            outbase = outprefix + '.' + label + '_rev'
        else:
            # No output name is defined for other strand values.
            sys.stderr.write('skipping ' + label + ': unsupported strand "'
                             + strand + '" (use -unstranded?)\n')
            continue

        # Peak positions inside [RL, RR), in transcription order.
        positions = sorted(p for p in PeakDict[chrom] if RL <= p < RR)
        if strand == '-':
            positions.reverse()

        # (number of scored region positions, {position: score}) per
        # read that spans all peak positions +/- R.
        spanning = []
        maxScoresLength = 0
        if positions:
            needLeft = min(positions) - R
            needRight = max(positions) + R
            # Every read kept below spans all positions, so it necessarily
            # overlaps the window around the last one; querying that window
            # is sufficient.
            anchor = positions[-1]
            cmd = (tabix + ' ' + readsFile + ' ' + chrom + ':'
                   + str(anchor - R) + '-' + str(anchor + R))
            for line in _iter_lines(cmd):
                fields = line.split('\t')
                read_left = int(fields[1])
                read_right = int(fields[2])
                RN += 1
                if RN % 100000 == 0:
                    print(str(RN) + ' reads processed')
                if not (read_left <= needLeft and read_right >= needRight):
                    continue
                RD = dict((int(x), float(y))
                          for x, y in zip(fields[6].split(','), fields[7].split(',')))
                MS = sum(1 for i in RD if RL <= i < RR)
                maxScoresLength = max(MS, maxScoresLength)
                spanning.append((MS, RD))

        # state pattern (tuple of 0/1 per position) -> list of per-bp score rows
        ClusterDict = {}
        with open(outbase + '.txt', 'w') as outfile:
            outfile.write('#' + label + ''.join('\t' + str(p) for p in positions) + '\n')
            for MS, RD in spanning:
                if MS < maxScoresLength:
                    continue
                states = _call_states(RD, positions, R, alph, bet, PSS)
                outfile.write(chrom + ':' + str(RL) + '-' + str(RR)
                              + ''.join('\t' + str(s) for s in states) + '\n')
                if not doCluster:
                    continue
                clustervalues = [RD.get(i, 0.0) for i in range(RL, RR)]
                if all(v == 0 for v in clustervalues):
                    continue
                if strand == '-':
                    clustervalues.reverse()
                ClusterDict.setdefault(tuple(states), []).append(clustervalues)

        if not doCluster:
            continue

        print('outputing read-level clustering')
        if not ClusterDict:
            # Nothing to cluster or plot (linkage() and the aspect-ratio
            # computation both need at least one row).
            sys.stderr.write('no reads to cluster for ' + label + '; skipping plots\n')
            continue

        FinalMatrix = []        # binned raw scores, one row per read
        FinalMatrixStates = []  # binned state footprint, one row per read
        patterns = list(ClusterDict)
        for k in _linkage_order([list(pattern) for pattern in patterns]):
            pattern = patterns[k]
            rows = ClusterDict[pattern]

            for kk in _linkage_order(rows):
                FinalMatrix.append(_bin_row(rows[kk], window, True))

            # All reads sharing a pattern share the same footprint: 1 within
            # R of every position called 1, 0.01 elsewhere.
            track = [0.01] * (RR - RL)
            for index, pos in enumerate(positions):
                if pattern[index] == 1:
                    for j in range(max(RL, pos - R), min(RR, pos + R)):
                        track[j - RL] = 1
            if strand == '-':
                track.reverse()
            binned = _bin_row(track, window, False)
            FinalMatrixStates.extend(list(binned) for _ in rows)

        _save_scatter(FinalMatrix, RR - RL, window, resize, SPcs, SPmin, SPmax,
                      SPedge, outbase + '.raw_data.scatter')
        _save_scatter(FinalMatrixStates, RR - RL, window, resize, SPcs, SPmin, SPmax,
                      SPedge, outbase + '.states.scatter')

# Start the pipeline only when this file is executed as a script, so that
# importing the module does not parse sys.argv or launch any processing.
if __name__ == '__main__':
    run()

