##################################
#                                #
# Last modified 2018/05/03       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import numpy as np
import random
import os
from scipy.cluster.hierarchy import dendrogram, linkage, fcluster
from sets import Set
import matplotlib, copy
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from pylab import *
from matplotlib.patches import Circle, RegularPolygon
from matplotlib.collections import PatchCollection


def getReverseComplement(preliminarysequence):
    """Return the reverse complement of a DNA sequence.

    Each base is swapped for its partner (A<->T, G<->C, N stays N) and the
    order of the bases is reversed.  Upper- and lower-case letters are
    complemented within their own case.  Any character other than
    A/C/G/T/N in either case raises ``KeyError``.
    """
    # Pair every supported base with its complement, case preserved.
    complement_of = dict(zip('ATGCNatgcn', 'TACGNtacgn'))
    # Walk the input back-to-front and look each base up in the table.
    return ''.join(complement_of[base] for base in reversed(preliminarysequence))

def _decompression_command(path):
    """Return the shell command that streams ``path`` to stdout.

    The tool is chosen from the file extension: ``bzip2 -cd`` for ``.bz2``,
    ``gunzip -c`` for ``.gz``, ``unzip -p`` for ``.zip`` and ``cat`` for
    everything else.
    """
    # NOTE(review): the path is pasted into the shell command unquoted, so
    # paths containing spaces or shell metacharacters will not work.
    if path.endswith('.bz2'):
        return 'bzip2 -cd ' + path
    if path.endswith('.gz'):
        return 'gunzip -c ' + path
    if path.endswith('.zip'):
        return 'unzip -p ' + path
    return 'cat ' + path


def _option_value(flag, offset=1):
    """Return the command-line token ``offset`` positions after ``flag``."""
    return sys.argv[sys.argv.index(flag) + offset]


def run():
    """Build, cluster and write per-region methylation score matrices.

    Everything is driven by ``sys.argv``:

    * positional: reads table, region (peak) list, the 0-based column
      indices of chromosome / left / right / strand in the region list, and
      the output prefix;
    * optional flags: ``-resize``, ``-subset``, ``-label``, ``-minCov``,
      ``-minReads``, ``-unstranded``, ``-heatmap``, ``-scatterPlot``,
      ``-window``, ``-deleteMatrix`` and ``-binarize`` (see the usage text
      printed when too few arguments are given).

    For every region with at least two (and at least ``-minReads``) reads
    that start-overlap it, the per-position scores of each read are
    averaged in windows of ``-window`` bp, optionally binarized, clustered
    with Ward linkage, and written to ``<prefix>.<label>.matrix`` (one row
    per read, prefixed by its flat-cluster id).  Rows are reversed for
    regions on the ``-`` strand.  Optionally a scatter plot (png + eps) is
    saved and/or an external heatmap script is run on the matrix.

    Exits with status 1 after printing the usage text when fewer than the
    seven required positional arguments are supplied.
    """
    # Bind the stdlib modules locally: the file also does
    # ``from pylab import *``, which may rebind ``random`` to numpy's
    # function and is not guaranteed to provide ``math``.
    import math
    import random

    # Seven positional arguments are read below (sys.argv[1..7]), so the
    # argument vector needs at least eight entries.
    if len(sys.argv) < 8:
        print('usage: python %s methylation_reads_all.tsv peak_list chrFieldID leftFieldID rightFieldID strandFieldID outfile_prefix [-resize factor] [-subset N] [-label fieldID] [-minCov fraction] [-minReads N] [-unstranded] [-heatmap path_to_heatmap.py x_pixel_size y_pixel_size colorscheme width(inches,dpi) minScore maxScore] [-scatterPlot colorscheme minScore maxScore color|none] [-window bp] [-deleteMatrix] [-binarize threshold]' % sys.argv[0])
        print('\tThe methylation_reads_all.tsv file is expected to look like this:')
        print('\t\tchrom   start   end     strand  read_name       seq     cgs     log_like')
        print('\t\tchr2L   384364  384936  +       25742d91-f901-4e32-82bf-749187c72c59    .....................$...........................$.................................................................................................0........0.....')
        print('\t\t..........$.......................$..........................0.................0.....0......0...............0............*..*...*.....*..*..*....*.....*.....*.....*.....*...*....*...........................*..............*....*.......')
        print('\t\t.....*.....*....*...*.....*....*....*........*..*.....*.........................*......................*.*.............*..*.......*..................*......*........*.....*....        384385,384413,384511,384536,384560,384587,384605,3')
        print('84634   0.57,-1.28,-4.26,1.86,-0.49,-3.15,-4.87,-3.3')
        print('\tUse the [-subset] option if you want only N of the fragments; the script will pick the N fragments best covering each region, and will discard regions with fewer than N covering fragments')
        print('\tUse the [-label] option if you want regions to be labeled with something other than their coordinates')
        print('\tThe [-heatmap] option will generate png heatmaps instead of text file matrices')
        print('\tThe [-minCov] option will remove all fragments that cover the region at less than the specified fraction')
        sys.exit(1)

    reads = sys.argv[1]
    peaks = sys.argv[2]
    chrFieldID = int(sys.argv[3])
    leftFieldID = int(sys.argv[4])
    rightFieldID = int(sys.argv[5])
    strandFieldID = int(sys.argv[6])
    outprefix = sys.argv[7]

    # No genome sequence is actually loaded by this function; the message
    # is kept so the log output stays unchanged.
    print('finished inputting genomic sequence')

    # ---- optional flags -------------------------------------------------
    doDeleteMatrix = '-deleteMatrix' in sys.argv
    if doDeleteMatrix:
        print('will delete raw matrix files')

    doBinary = '-binarize' in sys.argv
    if doBinary:
        BinaryThreshold = float(_option_value('-binarize'))
        print('will binarize scores at the  %s threshold' % BinaryThreshold)

    doSubset = '-subset' in sys.argv
    if doSubset:
        Nsub = int(_option_value('-subset'))
        print('will only output the most complete %s fragments' % Nsub)

    window = 1
    if '-window' in sys.argv:
        window = int(_option_value('-window'))
        print('will average scores across windows of size %s bp' % window)

    MR = 0
    if '-minReads' in sys.argv:
        MR = int(_option_value('-minReads'))
        print('will only output regions with at least %s reads covering the region' % MR)

    MFC = 0
    if '-minCov' in sys.argv:
        MFC = float(_option_value('-minCov'))
        print('will only output fragments with at least %s fractional coverage of the region' % MFC)

    doLabel = '-label' in sys.argv
    if doLabel:
        labelFieldID = int(_option_value('-label'))

    doScatterPlot = '-scatterPlot' in sys.argv
    resize = 1
    if doScatterPlot:
        print('will output scatter plot')
        SPcs = _option_value('-scatterPlot', 1)
        SPmin = float(_option_value('-scatterPlot', 2))
        SPmax = float(_option_value('-scatterPlot', 3))
        SPedge = _option_value('-scatterPlot', 4)
        # -resize is only honoured together with -scatterPlot.
        if '-resize' in sys.argv:
            resize = float(_option_value('-resize'))
            print('will resize scatter plots by a factor of %s' % resize)

    doHM = '-heatmap' in sys.argv
    if doHM:
        print('will output heatmap')
        HMpy = _option_value('-heatmap', 1)
        HMxp = int(_option_value('-heatmap', 2))
        HMyp = int(_option_value('-heatmap', 3))
        HMcs = _option_value('-heatmap', 4)
        HMincdpi = _option_value('-heatmap', 5)
        HMmin = _option_value('-heatmap', 6)
        HMmax = _option_value('-heatmap', 7)

    doNS = '-unstranded' in sys.argv
    if doNS:
        print('will not treat regions as stranded')

    # ---- regions: RegionDict[chrom][start][end][strand] = label ---------
    RegionDict = {}
    peak_pipe = os.popen(_decompression_command(peaks), "r")
    try:
        for raw in peak_pipe:
            line = raw.strip()
            # Blank lines and '#' comments are skipped (a blank line no
            # longer terminates parsing of the file).
            if line == '' or line.startswith('#'):
                continue
            linefields = line.split('\t')
            chrom = linefields[chrFieldID]
            region_start = max(0, int(linefields[leftFieldID]))
            region_end = int(linefields[rightFieldID])
            if doNS:
                strand = '+'
            else:
                strand = linefields[strandFieldID]
            if doLabel:
                label = linefields[labelFieldID]
            else:
                label = '1'
            by_end = RegionDict.setdefault(chrom, {}).setdefault(region_start, {})
            by_end.setdefault(region_end, {})[strand] = label
    finally:
        peak_pipe.close()

    print('finished inputting regions')

    # ---- reads ----------------------------------------------------------
    # ReadDict maps a running read number to the parsed read record;
    # RegionToReadDict maps (chrom, start, end) to the read numbers kept
    # for that region.
    ReadDict = {}
    RegionToReadDict = {}
    header_prefix = 'chrom\tstart\tend\tstrand\tread_name\tseq\tcgs\tlog_like'

    RN = 0
    read_pipe = os.popen(_decompression_command(reads), "r")
    try:
        for raw in read_pipe:
            line = raw.strip()
            if line == '' or line.startswith(header_prefix):
                continue
            fields = line.split('\t')
            RN += 1
            if RN % 10000 == 0:
                print('%s reads processed' % RN)
            chrom = fields[0]
            if chrom not in RegionDict:
                continue
            left = int(fields[1])
            right = int(fields[2])
            ReadDict[RN] = (chrom, left, right, fields[3], fields[5], fields[6], fields[7])
            starts_here = RegionDict[chrom]
            # Only regions whose start lies inside the read are considered.
            for i in range(left, right):
                if i not in starts_here:
                    continue
                for j in starts_here[i]:
                    # Zero/negative-length regions cannot be covered and
                    # would divide by zero below.
                    if j <= i:
                        continue
                    if (min(j, right) - max(left, i)) / (j - i - 0.0) >= MFC:
                        RegionToReadDict.setdefault((chrom, i, j), []).append(RN)
    finally:
        read_pipe.close()

    print('finished inputting reads')

    # ---- per-region matrices -------------------------------------------
    for chrom in sorted(RegionDict):
        for RL in sorted(RegionDict[chrom]):
            for RR in RegionDict[chrom][RL]:
                region_reads = RegionToReadDict.get((chrom, RL, RR))
                if region_reads is None:
                    continue
                # Clustering needs at least two reads.
                if len(region_reads) < 2 or len(region_reads) < MR:
                    continue
                if doSubset:
                    if len(region_reads) < Nsub:
                        continue
                    # NOTE(review): the usage text says "best covering",
                    # but the subset is a uniform random sample.
                    selected = random.sample(region_reads, Nsub)
                else:
                    selected = region_reads

                # Parse each selected read once: number of CpGs in the whole
                # read, position -> score table, and how many region
                # positions carry a score (kept as float for true division).
                read_tables = []
                maxScoresLength = 0.0
                for read_id in selected:
                    cgs = ReadDict[read_id][5]
                    loglike = ReadDict[read_id][6]
                    CGs = cgs.split(',')
                    LLs = loglike.split(',')
                    CGscoreDict = {}
                    for idx in range(len(CGs)):
                        CGscoreDict[int(CGs[idx])] = float(LLs[idx])
                    covered = 0.0
                    for pos in range(RL, RR):
                        if pos in CGscoreDict:
                            covered += 1
                    maxScoresLength = max(covered, maxScoresLength)
                    read_tables.append((len(CGs), CGscoreDict, covered))

                for strand in RegionDict[chrom][RL][RR]:
                    if doLabel:
                        label = RegionDict[chrom][RL][RR][strand]
                    else:
                        # Anything that is not '-' is processed like the
                        # forward strand below, so it is labelled '_for'.
                        suffix = '_rev' if strand == '-' else '_for'
                        label = chrom + '_' + str(RL) + '-' + str(RR) + suffix

                    # One row of per-position scores per read; 0 marks
                    # "no score at this position".
                    Matrix = []
                    for (n_cpgs, CGscoreDict, covered) in read_tables:
                        scores = [CGscoreDict.get(pos, 0) for pos in range(RL, RR)]
                        # Drop reads without any non-zero score in the region.
                        if scores and not any(scores):
                            continue
                        # Drop reads covering too few positions relative to
                        # the best-covered read.
                        if covered / maxScoresLength < MFC:
                            continue
                        Matrix.append((n_cpgs, scores))
                    # Reads with the most CpGs first.
                    Matrix.sort()
                    Matrix.reverse()

                    NewMatrix = []
                    for (n_cpgs, scores) in Matrix:
                        newscores = []
                        for i in range(0, len(scores), window):
                            # Average the non-zero scores inside the window.
                            ss = [v for v in scores[i:i + window] if v != 0]
                            s = 0
                            if len(ss) > 0:
                                s += np.mean(ss)
                            if doBinary:
                                if s != 0:
                                    if s >= BinaryThreshold:
                                        newscores.append(1.0)
                                    else:
                                        # Small non-zero value keeps "scored
                                        # but below threshold" apart from
                                        # "no data" (0).
                                        newscores.append(0.01)
                                else:
                                    newscores.append(0)
                            else:
                                newscores.append(s)
                        NewMatrix.append(newscores)

                    if len(NewMatrix) < 2:
                        continue
                    if len(NewMatrix) < MR:
                        continue
                    print('%s reads retained for %s' % (len(NewMatrix), label))

                    Z = linkage(NewMatrix, method='ward', metric='euclidean', optimal_ordering=True)
                    clusters = fcluster(Z, 0, criterion='distance')
                    CDict = {}
                    for i in range(len(clusters)):
                        CDict.setdefault(clusters[i], []).append(i)

                    # Rows in cluster order, oriented for the region strand.
                    # Copies are made so the reversal is applied exactly
                    # once regardless of which outputs are requested.
                    ordered_rows = []
                    for C in sorted(CDict):
                        for k in CDict[C]:
                            row = NewMatrix[k]
                            if strand == '-':
                                row = row[::-1]
                            ordered_rows.append((C, row))

                    columns = list(range(RL, RR, window))

                    if doScatterPlot:
                        # Only windows with a non-zero score are drawn.
                        newXaxis = []
                        newYaxis = []
                        newColorScores = []
                        for y, (C, row) in enumerate(ordered_rows):
                            for x, value in zip(columns, row):
                                if math.fabs(value) > 0.0:
                                    newXaxis.append(x)
                                    newYaxis.append(y)
                                    newColorScores.append(value)

                        aspectRatio = ((RR - RL) / (window + 0.0)) / len(NewMatrix)

                        rect = 0.10, 0.10, 0.8, 0.8
                        fig = plt.figure(figsize=(20 * resize, 20 * resize / aspectRatio))
                        # Both axes are created as before; drawing happens
                        # on the second one.
                        ax = fig.add_subplot(1, 1, 1, aspect='equal')
                        ax = fig.add_axes(rect)
                        lowerlimitX = columns[0] - 2 * window
                        upperlimitX = columns[-1] + 2 * window
                        lowerlimitY = -1
                        upperlimitY = len(NewMatrix)
                        if SPedge == 'none':
                            ax.scatter(newXaxis, newYaxis, marker='o', c=newColorScores, vmin=SPmin, vmax=SPmax, cmap=SPcs)
                        else:
                            ax.scatter(newXaxis, newYaxis, marker='o', edgecolor=SPedge, c=newColorScores, vmin=SPmin, vmax=SPmax, cmap=SPcs)
                        ax.set_xlim(lowerlimitX, upperlimitX)
                        ax.set_ylim(lowerlimitY, upperlimitY)
                        ax.set_yticklabels([], size=0, weight='bold')
                        ax.set_xticklabels([], size=0, weight='bold')

                        fig.savefig(outprefix + '.' + label + '.scatter.png')
                        fig.savefig(outprefix + '.' + label + '.scatter.eps', format='eps')
                        # Release the figure; one is created per region.
                        plt.close(fig)

                    matrix_path = outprefix + '.' + label + '.matrix'
                    with open(matrix_path, 'w') as outfile:
                        header = ['#' + chrom] + [str(pos) for pos in columns]
                        outfile.write('\t'.join(header) + '\n')
                        for (C, row) in ordered_rows:
                            cells = [str(C)]
                            for s in row:
                                if doBinary:
                                    # NOTE(review): values here are already
                                    # 1.0 / 0.01 / 0, yet they are compared
                                    # against the raw threshold again;
                                    # behaviour kept as-is, please confirm.
                                    if s >= BinaryThreshold:
                                        cells.append('1')
                                    else:
                                        cells.append('0')
                                else:
                                    cells.append(str(s))
                            outfile.write('\t'.join(cells) + '\n')

                    if doHM:
                        cmd = 'python ' + HMpy + ' ' + matrix_path + ' ' + str(HMxp) + ' ' + str(HMyp) + ' ' + HMmin + ' ' + HMmax + ' ' + HMcs + ' ' + HMincdpi + ' ' + matrix_path + '.heatmap.png'
                        os.system(cmd)
                    if doDeleteMatrix:
                        # Remove directly instead of shelling out to `rm`,
                        # so labels with spaces or shell characters work.
                        os.remove(matrix_path)

# Script entry point: run() is invoked unconditionally, i.e. also when this
# file is imported as a module rather than executed directly.
run()

