##################################
#                                #
# Last modified 2018/12/10       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import matplotlib, copy
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from pylab import *
from matplotlib.patches import Circle, RegularPolygon
from matplotlib.collections import PatchCollection
import numpy as np
from scipy.stats import fisher_exact
from scipy.stats import beta
from scipy.stats import binom
from scipy.stats import norm
import random
import os
import math
from sets import Set
from sklearn.metrics import normalized_mutual_info_score as NMIS
from scipy.cluster.hierarchy import dendrogram, linkage, fcluster


def _decompress_command(path):
    """Return the shell command that streams *path* to stdout as plain text.

    The decompressor is chosen from the file extension only; any other
    extension is streamed with ``cat``.
    """
    if path.endswith('.bz2'):
        return 'bzip2 -cd ' + path
    if path.endswith('.gz') or path.endswith('.bgz'):
        return 'zcat ' + path
    if path.endswith('.zip'):
        return 'unzip -p ' + path
    return 'cat ' + path


def _data_lines(path):
    """Yield the stripped data lines of a (possibly compressed) text file.

    Blank lines and lines starting with '#' are skipped, so a stray empty
    line in the middle of the file no longer truncates the input.
    """
    pipe = os.popen(_decompress_command(path), "r")
    try:
        for raw in pipe:
            line = raw.strip()
            if line == '' or line.startswith('#'):
                continue
            yield line
    finally:
        pipe.close()


def _remove_quietly(path):
    """Delete *path*; a missing file is not an error."""
    try:
        os.remove(path)
    except OSError:
        pass


def _window_means(scores, window):
    """Collapse *scores* into consecutive bins of *window* values.

    Each bin becomes the mean of its non-zero values, or 0.0 when every
    value in the bin is zero (zero is used upstream to mean "no data").
    """
    binned = []
    for start in range(0, len(scores), window):
        nonzero = [v for v in scores[start:start + window] if v != 0]
        if nonzero:
            binned.append(0.0 + np.mean(nonzero))
        else:
            binned.append(0.0)
    return binned


def _hierarchical_order(rows):
    """Return *rows* reordered by Ward hierarchical clustering.

    Rows are grouped by their flat-cluster id at distance 0 and emitted in
    increasing id order. Fewer than two rows cannot be linked and are
    returned unchanged.
    """
    if len(rows) < 2:
        return list(rows)
    tree = linkage(rows, method='ward', metric='euclidean', optimal_ordering=True)
    flat_ids = fcluster(tree, 0, criterion='distance')
    members = {}
    for index, flat_id in enumerate(flat_ids):
        members.setdefault(flat_id, []).append(index)
    ordered = []
    for flat_id in sorted(members):
        for index in members[flat_id]:
            ordered.append(rows[index])
    return ordered


def _kmeans_order(rows, kmeans_script, tmp_in, tmp_out):
    """Order *rows* with the external R script, else hierarchically.

    The rows are written to *tmp_in* (one 'R<n>' labelled line per row),
    ``Rscript kmeans_script --file=tmp_in`` is run, and the reordered rows
    are read back from *tmp_out* (a header line starting with 'ID<TAB>' is
    skipped). If any step fails, the rows are ordered by
    ``_hierarchical_order`` instead. Both temporary files are removed.
    """
    try:
        try:
            with open(tmp_in, 'w') as handle:
                for number, row in enumerate(rows, 1):
                    handle.write('R' + str(number) + ''.join('\t' + str(v) for v in row) + '\n')
            cmd = 'Rscript ' + kmeans_script + ' --file=' + tmp_in
            print(cmd)
            os.system(cmd)
            ordered = []
            with open(tmp_out) as handle:
                for tline in handle:
                    if tline.startswith('ID\t'):
                        continue
                    ordered.append([float(v) for v in tline.split('\t')[1:]])
            return ordered
        finally:
            # Never leave a stale result behind: a later failed Rscript run
            # would otherwise read the previous state's ordering.
            _remove_quietly(tmp_in)
            _remove_quietly(tmp_out)
    except Exception:
        # Best effort: Rscript missing, script failure or unparsable output.
        sys.stderr.write('k-means ordering failed; falling back to hierarchical clustering\n')
        return _hierarchical_order(rows)


def _save_scatter(matrix, span, window, resize, color_map, score_min, score_max, edge, out_base):
    """Plot *matrix* as a scatter heat map and save it as PNG and EPS.

    Rows are plotted along Y and columns along X; cells whose value is 0 are
    not drawn. *span* is the region length in bases and, together with
    *window*, sets the figure aspect ratio. *edge* is a marker edge colour,
    or the string 'none' to use the matplotlib default. Files are written
    to ``out_base + '.png'`` and ``out_base + '.eps'``. Nothing is written
    when *matrix* has no cells.
    """
    all_x = []
    all_y = []
    shown_x = []
    shown_y = []
    shown_c = []
    for row_index, row in enumerate(matrix):
        for col_index, value in enumerate(row):
            all_x.append(col_index)
            all_y.append(row_index)
            if math.fabs(value) > 0.0:
                shown_x.append(col_index)
                shown_y.append(row_index)
                shown_c.append(value)
    if not all_x:
        return

    aspect_ratio = (span / (window + 0.0)) / len(matrix)
    fig = plt.figure(figsize=(20 * resize, 20 * resize / aspect_ratio))
    # Kept as in the original figure layout: a background subplot plus the
    # axes that actually carry the scatter.
    fig.add_subplot(1, 1, 1, aspect='equal')
    ax = fig.add_axes((0.10, 0.10, 0.8, 0.8))
    style = dict(marker='o', c=shown_c, vmin=score_min, vmax=score_max, cmap=color_map)
    if edge != 'none':
        style['edgecolor'] = edge
    ax.scatter(shown_x, shown_y, **style)
    # Limits come from the full grid so that undrawn cells still take space.
    ax.set_xlim(min(all_x) - 2 * window, max(all_x) + 2 * window)
    ax.set_ylim(min(all_y) - 1, max(all_y) + 1)
    fig.savefig(out_base + '.png')
    fig.savefig(out_base + '.eps', format='eps')
    # Release the figure; one pair of figures is created per region.
    plt.close(fig)


def run():
    """Command-line entry point.

    Positional arguments (``sys.argv[1:13]``): reads file (queried through
    tabix), regions file, region chromosome/left/right/strand field indexes,
    peaks file, peak chromosome/position field indexes, radius ``R``, tabix
    executable, output prefix. Optional flags: ``-unstranded``,
    ``-label fieldID`` and
    ``-cluster window colorscheme minScore maxScore color|none resize``.

    For every region on a chromosome that has peaks, writes
    ``<outprefix>.<label>_for.txt`` or ``..._rev.txt``: a '#'-prefixed header
    with the peak positions inside the region, then one line per retained
    read with a 0/1 state per position. A read is retained when it spans
    all positions plus/minus ``R`` and has the maximal number of scored
    bases inside the region. The state is 1 when the mean of a Beta
    posterior (prior (10, 10), 100 pseudo-counts per scored base within
    ``R`` of the position) is below 0.5.

    With ``-cluster`` the reads are grouped by state vector, ordered, and
    two scatter plots (raw data and states) are saved as PNG and EPS.

    Regions whose strand is neither '+' nor '-' are skipped with a warning.
    Exits with status 1 after printing usage when arguments are missing.
    """
    # 12 positional arguments are read below, so argv needs 13 entries.
    if len(sys.argv) < 13:
        print('usage: python %s methylation_reads_all.tsv regions.bed chrFieldID leftField rightField strandFieldID peaks chrFieldID posFieldID radius tabix_location outfile [-unstranded] [-label fieldID] [-cluster window colorscheme minScore maxScore color|none resize]' % sys.argv[0])
        print('\tNote: the script assumes Tombo 1.3 probabilities, a tabix indexed reads file, and uses a beta distribution prior of (10,10) by default')
        sys.exit(1)

    reads = sys.argv[1]
    regions = sys.argv[2]
    regionchrFieldID = int(sys.argv[3])
    regionleftFieldID = int(sys.argv[4])
    regionrightFieldID = int(sys.argv[5])
    regionstrandFieldID = int(sys.argv[6])
    peaks = sys.argv[7]
    chrFieldID = int(sys.argv[8])
    posFieldID = int(sys.argv[9])
    R = int(sys.argv[10])
    tabix = sys.argv[11]
    outprefix = sys.argv[12]

    # The R helper is expected next to this script, however it was invoked.
    kmeansR = os.path.join(os.path.dirname(os.path.abspath(sys.argv[0])), 'kmeans_tsv.R')
    tmp_in = outprefix + 'tmp.tsv'
    tmp_out = outprefix + 'tmp_new_order.tsv'

    doLabel = '-label' in sys.argv
    if doLabel:
        labelFieldID = int(sys.argv[sys.argv.index('-label') + 1])

    doCluster = '-cluster' in sys.argv
    if doCluster:
        first = sys.argv.index('-cluster') + 1
        window = int(sys.argv[first])
        SPcs = sys.argv[first + 1]
        SPmin = float(sys.argv[first + 2])
        SPmax = float(sys.argv[first + 3])
        SPedge = sys.argv[first + 4]
        resize = float(sys.argv[first + 5])

    doUnstranded = '-unstranded' in sys.argv

    # Beta prior and the pseudo-sample size each scored base contributes.
    alph = 10
    bet = 10
    PSS = 100

    doBinary = True
    BinaryThreshold = 0.5

    # chromosome -> set of peak positions
    PeakDict = {}
    for peakline in _data_lines(peaks):
        fields = peakline.split('\t')
        PeakDict.setdefault(fields[chrFieldID], set()).add(int(fields[posFieldID]))

    RN = 0
    for regionline in _data_lines(regions):
        fields = regionline.split('\t')
        chrom = fields[regionchrFieldID]
        RL = int(fields[regionleftFieldID])
        RR = int(fields[regionrightFieldID])
        if doUnstranded:
            strand = '+'
        else:
            strand = fields[regionstrandFieldID]
        if chrom not in PeakDict:
            continue
        if doLabel:
            label = fields[labelFieldID]
        else:
            label = chrom + '_' + str(RL) + '-' + str(RR)
        print(label)
        if strand == '+':
            outfileprefix = outprefix + '.' + label + '_for'
        elif strand == '-':
            outfileprefix = outprefix + '.' + label + '_rev'
        else:
            # No output file name is defined for other strand values.
            sys.stderr.write('skipping %s: unsupported strand %r\n' % (label, strand))
            continue

        positions = sorted(p for p in PeakDict[chrom] if RL <= p < RR)
        if strand == '-':
            positions.reverse()

        # state vector (tuple of 0/1) -> list of binned raw-score rows
        ClusterDict = {}
        with open(outfileprefix + '.txt', 'w') as outfile:
            outfile.write('#' + label + ''.join('\t' + str(p) for p in positions) + '\n')

            candidates = []
            maxScoresLength = 0
            if positions:
                # Every retained read spans all positions +/- R, so querying
                # around a single position returns all of them.
                anchor = positions[-1]
                cmd = tabix + ' ' + reads + ' ' + chrom + ':' + str(max(1, anchor - R)) + '-' + str(anchor + R)
                span_left = min(positions) - R
                span_right = max(positions) + R
                p2 = os.popen(cmd, "r")
                try:
                    for raw in p2:
                        line = raw.strip()
                        if line == '':
                            continue
                        fields = line.split('\t')
                        read_left = int(fields[1])
                        read_right = int(fields[2])
                        RN += 1
                        if RN % 100000 == 0:
                            print('%d reads processed' % RN)
                        if read_left > span_left or read_right < span_right:
                            continue
                        # Columns 6 and 7: comma-separated positions and values.
                        RD = dict((int(x), float(y)) for x, y in
                                  zip(fields[6].split(','), fields[7].split(',')))
                        MS = sum(1 for i in RD if RL <= i < RR)
                        maxScoresLength = max(MS, maxScoresLength)
                        candidates.append((MS, RD))
                finally:
                    p2.close()

            for (MS, RD) in candidates:
                # Keep only reads with the maximal number of scored bases.
                if MS < maxScoresLength:
                    continue
                states = []
                for pos in positions:
                    A, B = alph, bet
                    for i in range(pos - R, pos + R):
                        if i in RD:
                            Z = int(PSS * RD[i])
                            A += Z
                            B += PSS - Z
                    if beta.mean(A, B) < 0.5:
                        states.append(1)
                    else:
                        states.append(0)
                outfile.write(chrom + ':' + str(RL) + '-' + str(RR) +
                              ''.join('\t' + str(s) for s in states) + '\n')
                if not doCluster:
                    continue

                scores = [RD.get(i, 0) for i in range(RL, RR)]
                if all(v == 0 for v in scores):
                    continue
                # One value per window, matching the states matrix and the
                # aspect ratio used for plotting.
                means = _window_means(scores, window)
                if doBinary:
                    newscores = []
                    for s in means:
                        if s != 0:
                            if s >= BinaryThreshold:
                                newscores.append(1.0)
                            else:
                                newscores.append(0.01)
                        else:
                            newscores.append(0)
                else:
                    newscores = means
                if strand == '-':
                    newscores.reverse()
                ClusterDict.setdefault(tuple(states), []).append(newscores)

        if not doCluster:
            continue
        if not ClusterDict:
            sys.stderr.write('no reads to cluster for %s\n' % label)
            continue

        print('outputing read-level clustering')

        StateClusters = _hierarchical_order([list(c) for c in ClusterDict.keys()])

        # One footprint row per read, identical for all reads of a state.
        FinalMatrixStates = []
        for state in StateClusters:
            cluster = tuple(state)
            footprint = [0.01] * (RR - RL)
            for idx, pos in enumerate(positions):
                if cluster[idx] == 1:
                    for j in range(max(RL, pos - R), min(RR, pos + R)):
                        footprint[j - RL] = 1
            if strand == '-':
                footprint.reverse()
            row = []
            for s in _window_means(footprint, window):
                if s != 0:
                    # Inverted on purpose relative to the raw data: windows
                    # covered by a state-1 footprint get the low value.
                    if s < 0.5:
                        row.append(1.0)
                    else:
                        row.append(0.01)
                else:
                    row.append(0.0)
            for _ in ClusterDict[cluster]:
                FinalMatrixStates.append(list(row))

        FinalMatrix = []
        n_states = len(ClusterDict)
        for SCID, state in enumerate(StateClusters, 1):
            cluster = tuple(state)
            print('%s %s' % (cluster, SCID))
            Matrix = ClusterDict[cluster]
            if len(Matrix) > 1:
                Matrix = _kmeans_order(Matrix, kmeansR, tmp_in, tmp_out)
                print('finished state-level clustering for state %d out of  %d' % (SCID, n_states))
            FinalMatrix.extend(Matrix)

        _save_scatter(FinalMatrix, RR - RL, window, resize, SPcs, SPmin, SPmax, SPedge,
                      outfileprefix + '.raw_data.scatter')
        _save_scatter(FinalMatrixStates, RR - RL, window, resize, SPcs, SPmin, SPmax, SPedge,
                      outfileprefix + '.states.scatter')

    if doCluster:
        # Safety net; the files only exist if a clustering step was interrupted.
        _remove_quietly(tmp_in)
        _remove_quietly(tmp_out)

# Run the pipeline only when executed as a script, so that importing this
# module (e.g. for testing) does not parse sys.argv or start processing.
if __name__ == '__main__':
    run()

