##################################
#                                #
# Last modified 2018/05/03       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import numpy as np
import random
import os
from scipy.cluster.hierarchy import dendrogram, linkage, fcluster
from sets import Set

def _iter_lines(path):
    """Yield each non-blank line of *path*, stripped of surrounding whitespace.

    Paths ending in .bz2, .gz or .zip are streamed through ``bzip2 -cd``,
    ``gunzip -c`` or ``unzip -p`` respectively; any other path is opened
    directly.  Blank lines are skipped instead of being treated as the end
    of the file, so a stray empty line no longer silently truncates input.
    Exits with an error message if the decompression command fails.
    """
    import subprocess

    if path.endswith('.bz2'):
        cmd = ['bzip2', '-cd', path]
    elif path.endswith('.gz'):
        cmd = ['gunzip', '-c', path]
    elif path.endswith('.zip'):
        cmd = ['unzip', '-p', path]
    else:
        cmd = None

    if cmd is None:
        proc = None
        handle = open(path)
    else:
        # An argument list (no shell) keeps paths containing spaces or shell
        # metacharacters intact.
        proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                                universal_newlines=True)
        handle = proc.stdout
    try:
        for raw in handle:
            line = raw.strip()
            if line:
                yield line
    finally:
        handle.close()
        if proc is not None:
            proc.wait()
    if proc is not None and proc.returncode != 0:
        sys.exit('error: %s exited with status %d while reading %s'
                 % (cmd[0], proc.returncode, path))


def _load_regions(path, chrFieldID, leftFieldID, rightFieldID, strandFieldID,
                  unstranded, labelFieldID):
    """Read the region file into RegionDict[chrom][start][end][strand] -> label.

    The label is column *labelFieldID*, or the placeholder '1' when
    *labelFieldID* is None.  Starts are clamped at 0, every region is
    recorded on strand '+' when *unstranded* is true, and lines beginning
    with '#' are ignored.
    """
    regions = {}
    for line in _iter_lines(path):
        if line.startswith('#'):
            continue
        fields = line.split('\t')
        chrom = fields[chrFieldID]
        start = max(0, int(fields[leftFieldID]))
        end = int(fields[rightFieldID])
        strand = '+' if unstranded else fields[strandFieldID]
        by_end = regions.setdefault(chrom, {}).setdefault(start, {})
        strands = by_end.setdefault(end, {})
        if labelFieldID is None:
            strands[strand] = '1'
        else:
            strands[strand] = fields[labelFieldID]
    print('finished inputting regions')
    return regions


def _load_reads(path, regions):
    """Collect every read's CpG positions and log-likelihood ratios.

    Returns ReadDict[chrom][read_name] = {'ps': positions, 'lls': ratios}.
    A row covering start..end contributes each position start..end
    (inclusive), all paired with that row's log_lik_ratio (kept as the raw
    string).  Header rows and rows on chromosomes absent from *regions* are
    skipped.
    """
    header = 'chromosome\tstart\tend\tread_name'
    reads = {}
    n_rows = 0
    for line in _iter_lines(path):
        if line.startswith(header):
            continue
        fields = line.split('\t')
        n_rows += 1
        if n_rows % 100000 == 0:
            print('%d lines processed' % n_rows)
        chrom = fields[0]
        if chrom not in regions:
            continue
        left = int(fields[1])
        right = int(fields[2])
        read_name = fields[3]
        loglike = fields[4]
        entry = reads.setdefault(chrom, {}).setdefault(
            read_name, {'ps': [], 'lls': []})
        span = list(range(left, right + 1))
        entry['ps'].extend(span)
        entry['lls'].extend([loglike] * len(span))
    print('finished inputting reads')
    return reads


def _match_reads(reads, regions, MFC):
    """Map (chrom, start, end) region keys to the IDs of reads assigned to them.

    A read whose stored positions span [left, right] is assigned to a region
    starting at s with left <= s < right when it covers at least the
    fraction *MFC* of that region: (min(end, right) - s) / (end - s) >= MFC.
    """
    matched = {}
    n_seen = 0
    for chrom in reads:
        print(chrom)
        chrom_regions = regions[chrom]
        for read_id, entry in reads[chrom].items():
            n_seen += 1
            if n_seen % 1000 == 0:
                print(n_seen)
            positions = entry['ps']
            if not positions:
                continue
            left = min(positions)
            right = max(positions)
            # NOTE(review): only regions that START inside the read are
            # examined, so a read beginning after a region's start is never
            # assigned to it regardless of -minCov -- confirm this is intended.
            for start in range(left, right):
                if start not in chrom_regions:
                    continue
                for end in chrom_regions[start]:
                    # Empty/inverted regions would divide by zero here and
                    # can never produce matrix columns.
                    if end <= start:
                        continue
                    covered = min(end, right) - start
                    if covered / float(end - start) >= MFC:
                        matched.setdefault((chrom, start, end), []).append(read_id)
    print('finished matching reads to regions')
    return matched


def _window_average(scores, window):
    """Average the non-zero scores within consecutive blocks of *window* entries.

    A score of exactly 0 marks "no call" (so a genuine 0.0 ratio is ignored
    too); a block without any call yields 0.
    """
    averaged = []
    for offset in range(0, len(scores), window):
        called = [v for v in scores[offset:offset + window] if v != 0]
        value = 0
        if called:
            value += np.mean(called)
        averaged.append(value)
    return averaged


def _region_rows(chrom_reads, read_ids, RL, RR, window):
    """Build the window-averaged score matrix for region [RL, RR).

    One row per read: the read's log-likelihood ratio at each position of
    range(RL, RR) (0 where it has no call; for a repeated position the last
    stored value wins).  Rows are ordered by descending (number of stored
    positions, raw scores) and then window-averaged.
    """
    raw_rows = []
    for read_id in read_ids:
        entry = chrom_reads[read_id]
        score_at = dict(zip(entry['ps'], [float(v) for v in entry['lls']]))
        scores = [score_at.get(pos, 0) for pos in range(RL, RR)]
        raw_rows.append((len(entry['ps']), scores))
    raw_rows.sort(reverse=True)
    return [_window_average(scores, window) for _, scores in raw_rows]


def _cluster_order(rows):
    """Return (cluster_id, row_index) pairs in output order.

    Ward linkage (euclidean metric, optimal leaf ordering) is cut with
    fcluster at distance 0; pairs are sorted by cluster id, then row index.
    """
    Z = linkage(rows, method='ward', metric='euclidean', optimal_ordering=True)
    clusters = fcluster(Z, 0, criterion='distance')
    return sorted(zip(clusters, range(len(clusters))))


def _default_label(chrom, RL, RR, strand):
    """Coordinate-based label used when -label is not given.

    '+' maps to '_for' and '-' to '_rev'; any other strand value (e.g. '.')
    gets '_unstranded' instead of reusing a stale label from another region.
    """
    suffix = {'+': '_for', '-': '_rev'}.get(strand, '_unstranded')
    return chrom + '_' + str(RL) + '-' + str(RR) + suffix


def _write_matrix(path, chrom, RL, RR, window, rows, order, reverse):
    """Write one clustered matrix file.

    The header is '#<chrom>' followed by the window start positions within
    [RL, RR); each further line is a cluster id and one row's scores,
    tab-separated, with the scores reversed when *reverse* is true.
    """
    with open(path, 'w') as outfile:
        outfile.write('#' + chrom + ''.join(
            '\t' + str(pos) for pos in range(RL, RR, window)) + '\n')
        for cluster_id, idx in order:
            row = rows[idx][::-1] if reverse else rows[idx]
            outfile.write(str(cluster_id) + ''.join(
                '\t' + str(score) for score in row) + '\n')


def run():
    """Command-line entry point: cluster per-read methylation calls by region.

    Positional arguments: methylation_reads_all.tsv peak_list chrFieldID
    leftFieldID rightFieldID strandFieldID outfile_prefix (see the usage text
    for the optional flags).  For each region with enough assigned reads it
    writes <outfile_prefix>.<label>.matrix (scores reversed for '-' strand
    regions), optionally renders it with an external heatmap script and
    optionally deletes the matrix afterwards.
    """
    import subprocess

    # Seven positional arguments plus the script name itself.
    if len(sys.argv) < 8:
        print('usage: python %s methylation_reads_all.tsv peak_list chrFieldID leftFieldID rightFieldID strandFieldID outfile_prefix [-subset N] [-label fieldID] [-minCov fraction] [-minReads N] [-unstranded] [-heatmap path_to_heatmap.py x_pixel_size y_pixel_size colorscheme width(inches,dpi) minScore maxScore] [-window bp] [-deleteMatrix]' % sys.argv[0])
        print('\tThe methylation_reads_all.tsv file is expected to look like this:')
        print('\t\tchromosome      start   end     read_name       log_lik_ratio   log_lik_methylated      log_lik_unmethylated    num_calling_strands     num_cpgs        sequence')
        print('\tUse the [-subset] option if you want only N of the fragments; the script will pick the N fragments best covering each region, and will discard regions with fewer than N covering fragments')
        print('\tUse the [-label] option if you want regions to be labeled with something other than their coordinates')
        print('\tThe [-heatmap] option will generate png heatmaps instead of text file matrices')
        print('\tThe [-minCov] option will remove all fragments that cover the region at less than the specified fraction')
        sys.exit(1)

    reads = sys.argv[1]
    peaks = sys.argv[2]
    chrFieldID = int(sys.argv[3])
    leftFieldID = int(sys.argv[4])
    rightFieldID = int(sys.argv[5])
    strandFieldID = int(sys.argv[6])
    outprefix = sys.argv[7]

    doDeleteMatrix = '-deleteMatrix' in sys.argv
    if doDeleteMatrix:
        print('will delete raw matrix files')

    doSubset = '-subset' in sys.argv
    if doSubset:
        Nsub = int(sys.argv[sys.argv.index('-subset') + 1])
        print('will only output the most complete %d fragments' % Nsub)

    window = 1
    if '-window' in sys.argv:
        window = int(sys.argv[sys.argv.index('-window') + 1])
        print('will average scores across windows of size %d bp' % window)

    MR = 0
    if '-minReads' in sys.argv:
        MR = int(sys.argv[sys.argv.index('-minReads') + 1])
        print('will only output regions with at least %d reads covering the region' % MR)

    MFC = 0
    if '-minCov' in sys.argv:
        MFC = float(sys.argv[sys.argv.index('-minCov') + 1])
        print('will only output fragments with at least %s fractional coverage of the region' % MFC)

    doLabel = '-label' in sys.argv
    labelFieldID = None
    if doLabel:
        labelFieldID = int(sys.argv[sys.argv.index('-label') + 1])

    doHM = '-heatmap' in sys.argv
    if doHM:
        print('will output heatmap')
        hm = sys.argv.index('-heatmap')
        HMpy = sys.argv[hm + 1]
        HMxp = int(sys.argv[hm + 2])
        HMyp = int(sys.argv[hm + 3])
        HMcs = sys.argv[hm + 4]
        HMincdpi = sys.argv[hm + 5]
        HMmin = sys.argv[hm + 6]
        HMmax = sys.argv[hm + 7]

    doNS = '-unstranded' in sys.argv
    if doNS:
        print('will not treat regions as stranded')

    RegionDict = _load_regions(peaks, chrFieldID, leftFieldID, rightFieldID,
                               strandFieldID, doNS, labelFieldID)
    ReadDict = _load_reads(reads, RegionDict)
    print(ReadDict.keys())
    RegionToReadDict = _match_reads(ReadDict, RegionDict, MFC)

    for chrom in sorted(RegionDict):
        for RL in sorted(RegionDict[chrom]):
            for RR in sorted(RegionDict[chrom][RL]):
                candidates = RegionToReadDict.get((chrom, RL, RR), [])
                # Clustering needs at least two reads; -minReads may ask for more.
                if len(candidates) < 2 or len(candidates) < MR:
                    continue
                if doSubset:
                    if len(candidates) < Nsub:
                        continue
                    # NOTE(review): the usage text promises the N best-covering
                    # fragments, but this draws N at random -- confirm intent.
                    selected = random.sample(candidates, Nsub)
                else:
                    selected = candidates
                # Rows and clusters do not depend on strand: build them once.
                X = _region_rows(ReadDict[chrom], selected, RL, RR, window)
                if len(X) < 2:
                    continue
                order = _cluster_order(X)
                strands = RegionDict[chrom][RL][RR]
                for strand in sorted(strands):
                    if doLabel:
                        label = strands[strand]
                    else:
                        label = _default_label(chrom, RL, RR, strand)
                    print('%d reads retained for %s' % (len(X), label))
                    matrix_path = outprefix + '.' + label + '.matrix'
                    _write_matrix(matrix_path, chrom, RL, RR, window, X, order,
                                  strand == '-')
                    if doHM:
                        # The heatmap script takes min/max before the colour
                        # scheme, unlike the order in the usage text.
                        subprocess.call(['python', HMpy, matrix_path,
                                         str(HMxp), str(HMyp), HMmin, HMmax,
                                         HMcs, HMincdpi, matrix_path + '.png'])
                    if doDeleteMatrix:
                        os.remove(matrix_path)
            
# Only run the pipeline when executed as a script, not when imported.
if __name__ == '__main__':
    run()

