##################################
#                                #
# Last modified 2018/05/14       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from sets import Set
import h5py    
import numpy as np    
from tombo import tombo_stats
import os

def getReverseComplement(preliminarysequence):
    """Return the reverse complement of a DNA sequence.

    Each base is complemented (A<->T, G<->C, N->N) with its case kept,
    and the order of the bases is reversed.

    Args:
        preliminarysequence: sequence made of the characters
            A, C, G, T, N in upper or lower case.

    Returns:
        The reverse-complemented sequence as a string; an empty input
        gives an empty string.

    Raises:
        KeyError: if the sequence contains any other character.
    """
    DNA = {'A':'T','T':'A','G':'C','C':'G','N':'N','a':'t','t':'a','g':'c','c':'g','n':'n'}
    # Build the result in one pass; repeated string concatenation is quadratic.
    return ''.join(DNA[base] for base in reversed(preliminarysequence))

def _load_genome(fasta):
    """Read a FASTA file into a dict mapping header text to sequence.

    The key is the text following the leading '>' of each header line
    (whitespace-stripped, cut at any further '>'); the value is the
    record's sequence lines joined together and upper-cased.  Lines that
    appear before the first header are ignored and an empty file yields
    an empty dict.
    """
    genome = {}
    name = None
    chunks = []
    with open(fasta) as inputdatafile:
        for line in inputdatafile:
            if line.startswith('>'):
                if name is not None:
                    genome[name] = ''.join(chunks).upper()
                name = line.strip().split('>')[1]
                chunks = []
            elif name is not None:
                chunks.append(line.strip())
    if name is not None:
        genome[name] = ''.join(chunks).upper()
    return genome


def _m5C_context(chrom_seq, pos, strand):
    """Return (seq, cg, gc) sequence contexts for the site at ``pos``.

    ``seq`` is the base at ``pos`` with one neighbour on each side,
    ``cg`` is the dinucleotide compared against 'CG' and ``gc`` the one
    compared against 'GC'.  For strand '-' all three are reverse
    complemented, which swaps which genomic dinucleotide plays which role.
    """
    # Clamp the start: a negative slice start would wrap around to the end
    # of the chromosome instead of stopping at its first base.
    left = max(0, pos - 1)
    seq = chrom_seq[left:pos + 2]
    downstream = chrom_seq[pos:pos + 2]
    upstream = chrom_seq[left:pos + 1]
    if strand == '-':
        return (getReverseComplement(seq),
                getReverseComplement(upstream),
                getReverseComplement(downstream))
    return seq, downstream, upstream


def _collect_m5C(fm5C, GenomeDict, ReadDict, outprefix, GConly, CGonly):
    """Gather 5mC calls into ReadDict and write the two context-stats files.

    Every site is written to '<outprefix>.GC_context_stats' (context and
    statistic) and '<outprefix>.GC_context_stats_6mer' (whether 'GC' or
    'CG' occurs in the surrounding window); only sites in the requested
    CG/GC context are appended to ReadDict.
    """
    with open(outprefix + '.GC_context_stats', 'w') as outfile:
        with open(outprefix + '.GC_context_stats_6mer', 'w') as outfile2:
            # Each block is indexed as B[0]=chromosome, B[1]=strand,
            # B[2] and B[3] (only echoed in the progress line) and
            # B[4]=iterable of (pos, ll, rid) records.
            for i, B in enumerate(fm5C):
                chrom = B[0]
                strand = B[1]
                if i % 100 == 0:
                    print('5mC %s %s %s %s %s' % (chrom, strand, B[2], B[3], i))
                chrom_seq = GenomeDict[chrom]
                for (pos, ll, rid) in B[4]:
                    # Leftover debugging trace, kept so stdout is unchanged.
                    if chrom == 'chrI' and i == 1264 and rid == 0:
                        print('%s %s %s' % (pos, ll, rid))
                    # The key is created even when no site passes the
                    # filter; such reads are reported as skipped later.
                    sites = ReadDict.setdefault((chrom, strand, i, rid), [])
                    seq, cg, gc = _m5C_context(chrom_seq, pos, strand)
                    if strand in ('+', '-'):
                        if GConly:
                            keep = gc == 'GC'
                        elif CGonly:
                            keep = cg == 'CG'
                        else:
                            keep = cg == 'CG' or gc == 'GC'
                        if keep:
                            sites.append((pos, ll))
                    outfile.write(seq + '\t' + str(ll) + '\n')
                    # Window of up to 5 bases on either side, clamped at the
                    # chromosome start for the same reason as above.
                    window = chrom_seq[max(0, pos - 5):pos + 5]
                    if 'GC' in window or 'CG' in window:
                        label = 'CG_GC-6mer'
                    else:
                        label = 'no_CG_GC-6mer'
                    outfile2.write(label + '\t' + str(ll) + '\n')


def _collect_m6A(fm6A, ReadDict):
    """Append every 6mA call to ReadDict, with no sequence-context filter."""
    for i, B in enumerate(fm6A):
        chrom = B[0]
        strand = B[1]
        if i % 100 == 0:
            print('m6A %s %s %s %s %s' % (chrom, strand, B[2], B[3], i))
        for (pos, ll, rid) in B[4]:
            ReadDict.setdefault((chrom, strand, i, rid), []).append((pos, ll))


def _write_reads(ReadDict, outprefix):
    """Write one line per read to '<outprefix>.reads.tsv'.

    Columns: chromosome, first position, last position, strand,
    '<block index>|<read id>', '.', comma-separated positions and
    comma-separated statistics formatted to two decimals.  Reads with no
    retained site are reported on stdout and skipped.
    """
    with open(outprefix + '.reads.tsv', 'w') as outfile:
        for key in sorted(ReadDict):
            (chrom, strand, i, rid) = key
            sites = sorted(ReadDict[key])
            if not sites:
                print('skipping: %s %s' % (key, sites))
                continue
            fields = [
                chrom,
                str(sites[0][0]),
                str(sites[-1][0]),
                strand,
                str(i) + '|' + str(rid),
                '.',
                ','.join(str(pos) for (pos, ll) in sites),
                ','.join("{0:.2f}".format(ll) for (pos, ll) in sites),
            ]
            outfile.write('\t'.join(fields) + '\n')


def run():
    """Merge Tombo per-read 5mC and 6mA statistics into per-read TSV lines.

    Command line (read from sys.argv):
        5mC.tombo.per_read_stats 6mA.tombo.per_read_stats genome.fa
        outfile_prefix [-m5C-only] [-m6A-only] [-CG-only] [-GC-only]

    '-m5C-only' / '-m6A-only' restrict output to one modification.
    '-CG-only' / '-GC-only' are honoured only together with '-m5C-only'
    and are mutually exclusive.  Without them, 5mC sites are kept when
    they sit in either a CG or a GC context.

    Calls from the two files are stored under the same
    (chromosome, strand, block index, read id) key, so they are combined
    only when both files enumerate blocks and reads identically
    (NOTE(review): confirm this holds for the inputs used).

    Exits with status 1 when fewer than four positional arguments are
    given or when incompatible options are combined.
    """
    # Four positional arguments are read below, so argv needs five entries.
    if len(sys.argv) < 5:
        print('usage: python %s 5mC.tombo.per_read_stats 6mA.tombo.per_read_stats genome.fa outfile_prefix [-m5C-only] [-m6A-only] [-CG-only] [-GC-only]' % sys.argv[0])
        print('\tnote: the [-CG-only] and [-GC-only] options only apply if they [-m5C-only] has been specified')
        sys.exit(1)

    do5C = True
    do6A = True
    # Defined unconditionally: the 5mC pass reads these flags even when
    # '-m5C-only' was not given.
    GConly = False
    CGonly = False

    if '-m5C-only' in sys.argv:
        do6A = False
        print('will only output m5C positions')
        if '-CG-only' in sys.argv:
            CGonly = True
            print('will only output m5C positions in CpG context')
        if '-GC-only' in sys.argv:
            GConly = True
            print('will only output m5C positions in GpC context')
        if GConly and CGonly:
            print('incompatible options, exiting')
            sys.exit(1)

    if '-m6A-only' in sys.argv:
        do5C = False
        print('will only output m6A positions')

    m5C = sys.argv[1]
    m6A = sys.argv[2]
    fasta = sys.argv[3]
    outprefix = sys.argv[4]

    GenomeDict = _load_genome(fasta)

    print('finished inputting genomic sequence')

    fm5C = tombo_stats.PerReadStats(m5C)
    fm6A = tombo_stats.PerReadStats(m6A)

    # (chromosome, strand, block index, read id) -> list of (pos, statistic)
    ReadDict = {}

    if do5C:
        _collect_m5C(fm5C, GenomeDict, ReadDict, outprefix, GConly, CGonly)

    if do6A:
        _collect_m6A(fm6A, ReadDict)

    _write_reads(ReadDict, outprefix)

#     for B in I:
#         print B[0] # chr
#         print B[1] # strand
#         print B[2] # start
#         print B[3] # end
#         print B[4] # array
#         print B
# 
#     groupC = fm5C['Statistic_Blocks']
#     for block in groupC.keys():
#         print block, groupC[block].keys()
#         Bstats = str(block) + '/block_stats'
#         bs = groupC[Bstats]
#         print Bstats, bs
#         for B in bs:
#             print B
#         Rstats = str(block) + '/read_ids'
#         rids = groupC[Rstats]
#         print rids
#         for read in rids:
#             print 'read:', read
# 
#     groupA = fm6A['Statistic_Blocks']
#     for block in groupA.keys():
#         print block, groupA[block].keys()
#         Bstats = str(block) + '/block_stats'
#         bs = groupC[Bstats]
#         print Bstats, bs
#         for B in bs:
#             print B
#         Rstats = str(block) + '/read_ids'
#         rids = groupC[Rstats]
#         print rids
#         for read in rids:
#             print 'read:', read

    
# Run the conversion only when executed as a script, so that importing this
# module (e.g. to reuse getReverseComplement) has no side effects.
if __name__ == '__main__':
    run()

# f = h5py.File(filename, 'r')

# for key in f.keys():
#     print(key)
