##################################
#                                #
# Last modified 2018/04/27       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import numpy as np
import random
import os
from scipy.cluster.hierarchy import dendrogram, linkage, fcluster
from sets import Set

def getReverseComplement(preliminarysequence):
    """Return the reverse complement of a DNA sequence.

    Only A, C, G, T and N are recognised, in upper or lower case; the case
    of each base is preserved. Any other character raises KeyError.
    """
    complement = {
        'A': 'T', 'T': 'A', 'G': 'C', 'C': 'G', 'N': 'N',
        'a': 't', 't': 'a', 'g': 'c', 'c': 'g', 'n': 'n',
    }
    # Walk the input from its last base to its first, complementing each one.
    return ''.join([complement[base] for base in reversed(preliminarysequence)])

def _iter_lines(path):
    """Yield the whitespace-stripped, non-empty lines of a text file.

    Files whose names end in '.bz2', '.gz' or '.zip' are streamed through
    'bzip2 -cd', 'gunzip -c' or 'unzip -p' respectively; any other path is
    opened directly. Blank lines are skipped rather than treated as end of
    file, so an empty line in the middle of a file does not truncate input.
    """
    import subprocess
    if path.endswith('.bz2'):
        cmd = ['bzip2', '-cd', path]
    elif path.endswith('.gz'):
        cmd = ['gunzip', '-c', path]
    elif path.endswith('.zip'):
        cmd = ['unzip', '-p', path]
    else:
        cmd = None
    proc = None
    if cmd is None:
        handle = open(path)
    else:
        # An argument list (no shell) passes paths containing spaces or
        # shell metacharacters to the tool unchanged.
        proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                                universal_newlines=True)
        handle = proc.stdout
    try:
        for raw in handle:
            line = raw.strip()
            if line:
                yield line
    finally:
        # Always release the stream and reap the decompressor process.
        handle.close()
        if proc is not None:
            proc.wait()


def run():
    """Command-line entry point: write per-read methylation scores around peaks.

    Positional arguments (sys.argv[1] .. sys.argv[8]):
      1  methylation_reads_all.tsv  per-read calls (plain, .bz2, .gz or .zip)
      2  peak_list                  regions file (plain, .bz2, .gz or .zip)
      3  chrFieldID                 0-based chromosome column in peak_list
      4  peakPosition|narrowPeak    0-based column holding the peak position,
                                    or the literal 'narrowPeak' to use
                                    column 1 + column 9 as the position
      5  strandFieldID              0-based strand column in peak_list
      6  leftRadius                 bases kept upstream of the peak position
      7  rightRadius                bases kept downstream of the peak position
      8  outfilename                output path
    Optional flags: -subset N, -minCov fraction, -minReads N, -window bp,
    -unstranded, -sortBy fieldID (see the usage message below).

    Each region is [peak - leftRadius, peak + rightRadius). A read is assigned
    to a region when the fraction of the region it covers is >= -minCov.
    Regions with fewer than 2 (or fewer than -minReads) reads are skipped.

    Output: a header row of '#' followed by the offsets
    range(-leftRadius, rightRadius, window). Then, per kept region, one row
    per assigned read, labelled 'chrom:start-end_RK'. Each row holds, for
    every window, the mean of the non-zero per-position log-likelihoods
    (0 when there are none). Regions are written in descending order of
    their -sortBy value. Rows within a region are written in descending
    order of the read's CpG count.

    Exits with status 1 (after printing usage) when arguments are missing
    or the window length is not positive.
    """
    if len(sys.argv) < 9:
        # Eight positional arguments are required: sys.argv[1] .. sys.argv[8].
        print('usage: python %s methylation_reads_all.tsv peak_list chrFieldID peakPosition|narrowPeak strandFieldID leftRadius rightRadius outfilename [-subset N] [-minCov fraction] [-minReads N] [-window bp] [-unstranded] [-sortBy fieldID]' % sys.argv[0])
        print('\tThe methylation_reads_all.tsv file is expected to look like this:')
        print('\t\tchrom   start   end     strand  read_name       seq     cgs     log_like')
        print('\t\tchr2L   384364  384936  +       25742d91-f901-4e32-82bf-749187c72c59    .....................$...........................$.................................................................................................0........0.....')
        print('\t\t..........$.......................$..........................0.................0.....0......0...............0............*..*...*.....*..*..*....*.....*.....*.....*.....*...*....*...........................*..............*....*.......')
        print('\t\t.....*.....*....*...*.....*....*....*........*..*.....*.........................*......................*.*.............*..*.......*..................*......*........*.....*....        384385,384413,384511,384536,384560,384587,384605,3')
        print('84634   0.57,-1.28,-4.26,1.86,-0.49,-3.15,-4.87,-3.3')
        print('\tThe [-minCov] option will remove all fragments that cover the region at less than the specified fraction')
        print('\tThe [-minReads] option will discard sites covered by less the indicated number of reads')
        sys.exit(1)

    reads = sys.argv[1]
    peaks = sys.argv[2]
    chrFieldID = int(sys.argv[3])
    position = sys.argv[4]
    doNarrowPeak = position == 'narrowPeak'
    if not doNarrowPeak:
        # Used to index the split line, so it must be an int, not a string.
        peakFieldID = int(position)
    strandFieldID = int(sys.argv[5])
    leftRadius = int(sys.argv[6])
    rightRadius = int(sys.argv[7])
    outfilename = sys.argv[8]
    if leftRadius + rightRadius <= 0:
        # The coverage fraction below divides by the region length.
        print('leftRadius + rightRadius must be positive')
        sys.exit(1)

    doSubset = False
    if '-subset' in sys.argv:
        doSubset = True
        Nsub = int(sys.argv[sys.argv.index('-subset') + 1])
        # NOTE(review): the reads are drawn with random.sample below, not
        # ranked by completeness; confirm the message/intent.
        print('will only output the most complete %s fragments' % Nsub)

    window = 1
    if '-window' in sys.argv:
        window = int(sys.argv[sys.argv.index('-window') + 1])
        print('will average scores across windows of size %s bp' % window)

    MR = 0
    if '-minReads' in sys.argv:
        MR = int(sys.argv[sys.argv.index('-minReads') + 1])
        print('will only output regions with at least %s reads covering the region' % MR)

    MFC = 0
    if '-minCov' in sys.argv:
        MFC = float(sys.argv[sys.argv.index('-minCov') + 1])
        print('will only output fragments with at least %s fractional coverage of the region' % MFC)

    doNS = False
    if '-unstranded' in sys.argv:
        doNS = True
        print('will not treat regions as stranded')

    sortFieldID = 0
    if '-sortBy' in sys.argv:
        sortFieldID = int(sys.argv[sys.argv.index('-sortBy') + 1])
        print('will sort regions by values in column %s (0-based)' % sortFieldID)

    # RegionDict[chrom][start][end][strand] = 1 ; regionList keeps one
    # (sortValue, chrom, start, end, strand) tuple per peak line.
    RegionDict = {}
    regionList = []
    for line in _iter_lines(peaks):
        if line.startswith('#'):
            continue
        linefields = line.split('\t')
        chrom = linefields[chrFieldID]
        if doNarrowPeak:
            # narrowPeak: chromStart (col 1) plus the summit offset (col 9).
            peak = int(linefields[1]) + int(linefields[9])
        else:
            peak = int(linefields[peakFieldID])
        strand = '+' if doNS else linefields[strandFieldID]
        start = peak - leftRadius
        end = peak + rightRadius
        RegionDict.setdefault(chrom, {}).setdefault(start, {}).setdefault(end, {})[strand] = 1
        # NOTE(review): the sort key is the raw column text, so numeric
        # columns are ordered lexicographically ('9' > '10').
        regionList.append((linefields[sortFieldID], chrom, start, end, strand))

    print('finished inputting regions')

    # ReadDict[readNumber] -> parsed read; only reads assigned to at least
    # one region are kept. RegionToReadDict[(chrom, start, end)] -> readNumbers.
    ReadDict = {}
    RegionToReadDict = {}
    readsHeader = 'chrom\tstart\tend\tstrand\tread_name\tseq\tcgs\tlog_like'
    RN = 0
    for line in _iter_lines(reads):
        if line.startswith(readsHeader):
            continue
        fields = line.split('\t')
        RN += 1
        if RN % 10000 == 0:
            print('%s reads processed' % RN)
        chrom = fields[0]
        if chrom not in RegionDict:
            continue
        starts = RegionDict[chrom]
        left = int(fields[1])
        right = int(fields[2])
        record = (chrom, left, right, fields[3], fields[5], fields[6], fields[7])
        # NOTE(review): only region starts inside [left, right) are looked
        # up, so a read that begins inside a region is never assigned to it
        # even though the overlap formula uses max(left, i); confirm intent.
        for i in range(left, right):
            if i not in starts:
                continue
            for j in starts[i]:
                if (min(j, right) - max(left, i)) / float(j - i) >= MFC:
                    RegionToReadDict.setdefault((chrom, i, j), []).append(RN)
                    ReadDict[RN] = record

    print('finished inputting reads')

    with open(outfilename, 'w') as outfile:
        headerCols = ['#'] + [str(i) for i in range(-leftRadius, rightRadius, window)]
        outfile.write('\t'.join(headerCols) + '\n')

        regionList.sort(reverse=True)

        # NOTE(review): RK counts reads processed across all regions; every
        # row of a region carries the value reached after that region's
        # reads, so it is not a per-row id. Kept for output compatibility.
        RK = 0
        for (sortValue, chrom, RL, RR, regionStrand) in regionList:
            readIDs = RegionToReadDict.get((chrom, RL, RR), [])
            if len(readIDs) < max(2, MR):
                continue
            if doSubset:
                if len(readIDs) < Nsub:
                    continue
                readIDs = random.sample(readIDs, Nsub)
            # NOTE(review): the rows do not depend on strand, so a region seen
            # on several strands or several peak lines is written repeatedly.
            for strand in RegionDict[chrom][RL][RR]:
                Matrix = []
                for readID in readIDs:
                    RK += 1
                    cgs = ReadDict[readID][5]
                    loglike = ReadDict[readID][6]
                    CGs = cgs.split(',')
                    LLs = loglike.split(',')
                    CGscoreDict = {}
                    for k, CG in enumerate(CGs):
                        CGscoreDict[int(CG)] = float(LLs[k])
                    # 0 marks positions without a CpG call.
                    scores = [CGscoreDict.get(pos, 0) for pos in range(RL, RR)]
                    Matrix.append((len(CGs), scores))
                Matrix.sort(reverse=True)
                label = '%s:%s-%s_%s' % (chrom, RL, RR, RK)
                for (lenCGs, scores) in Matrix:
                    outline = [label]
                    for i in range(0, len(scores), window):
                        # Average only real calls; zeros are "no data".
                        called = [x for x in scores[i:i + window] if x != 0]
                        s = 0
                        if called:
                            s += np.mean(called)
                        outline.append(str(s))
                    outfile.write('\t'.join(outline) + '\n')
            
# Run only when executed as a script, so importing this module (e.g. to reuse
# getReverseComplement) does not parse the importer's sys.argv.
if __name__ == '__main__':
    run()

