##################################
#                                #
# Last modified 2018/04/30       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from sets import Set
import h5py    
import numpy as np    
from tombo import tombo_stats
import os

# Complement of each base, case preserved; N/n map to themselves.
_COMPLEMENT = {'A': 'T', 'T': 'A', 'G': 'C', 'C': 'G', 'N': 'N',
               'a': 't', 't': 'a', 'g': 'c', 'c': 'g', 'n': 'n'}


def getReverseComplement(preliminarysequence):
    """Return the reverse complement of a DNA sequence.

    Each base is complemented (A<->T, G<->C, N->N, lower case preserved)
    and the result is read in reverse order. An empty input gives ''.

    Raises KeyError if the sequence contains any other character.
    """
    # A single join over a reversed iterator is linear; the previous
    # repeated string concatenation was quadratic in the sequence length.
    return ''.join(_COMPLEMENT[base] for base in reversed(preliminarysequence))

def _load_genome(fasta):
    """Read the FASTA file at *fasta* into a dict of header -> upper-cased sequence.

    The key is the header text after the leading '>' (up to any further '>').
    Sequence lines appearing before the first header are ignored, and an
    empty file yields an empty dict.
    """
    genome = {}
    name = None
    chunks = []
    with open(fasta) as handle:
        for line in handle:
            if line.startswith('>'):
                if name is not None:
                    genome[name] = ''.join(chunks).upper()
                name = line.strip().split('>')[1]
                chunks = []
            elif name is not None:
                chunks.append(line.strip())
    if name is not None:
        genome[name] = ''.join(chunks).upper()
    return genome


def _write_context_stats(stats, genome, outprefix):
    """Write each per-read value of *stats* keyed by sequence context.

    Two files are written:
      <outprefix>.context_stats       -- 3-base context around pos, then value
      <outprefix>.context_stats_6mer  -- CG/GC-presence label, then value
    """
    with open(outprefix + '.context_stats', 'w') as context_out:
        with open(outprefix + '.context_stats_6mer', 'w') as window_out:
            for i, block in enumerate(stats):
                chrom = block[0]
                strand = block[1]
                print('%s %s %s' % (chrom, strand, i))
                chrom_seq = genome[chrom]
                # NOTE(review): assumes block[4] holds (pos, value, read_id)
                # triples, as unpacked here -- confirm against tombo.
                for (pos, ll, _read_id) in block[4]:
                    # Clamp slice starts at 0: a negative start would wrap to
                    # the end of the chromosome and yield a wrong/empty context.
                    seq = chrom_seq[max(pos - 1, 0):pos + 2]
                    if strand == '-':
                        seq = getReverseComplement(seq)
                    context_out.write(seq + '\t' + str(ll) + '\n')
                    # NOTE(review): this window spans up to 10 bases although
                    # the label says "6mer" -- confirm the intended width.
                    window = chrom_seq[max(pos - 5, 0):pos + 5]
                    if 'GC' in window or 'CG' in window:
                        label = 'CG_GC-6mer'
                    else:
                        label = 'no_CG_GC-6mer'
                    window_out.write(label + '\t' + str(ll) + '\n')


def _print_blocks(stats, limit=None):
    """Print the records yielded by iterating *stats*, at most *limit* of them.

    Record layout assumed (chrom, strand, start, end, per-read array) --
    TODO confirm against tombo.
    """
    for count, block in enumerate(stats):
        if limit is not None and count >= limit:
            break
        print(block)


def _print_statistic_blocks(stats):
    """Print block stats and read ids for each entry under 'Statistic_Blocks'."""
    # NOTE(review): this indexes *stats* like an h5py.File/Group; confirm the
    # object passed in supports that.
    group = stats['Statistic_Blocks']
    for block in group.keys():
        print('%s %s' % (block, group[block].keys()))
        stats_path = str(block) + '/block_stats'
        block_stats = group[stats_path]
        print('%s %s' % (stats_path, block_stats))
        for record in block_stats:
            print(record)
        ids_path = str(block) + '/read_ids'
        read_ids = group[ids_path]
        print(read_ids)
        for read in read_ids:
            print('read: %s' % (read,))


def run():
    """Command-line entry point.

    usage: python <script> 5mC.tombo.per_read_stats 6mA.tombo.per_read_stats
           genome.fa outfile_prefix

    Loads the genome FASTA and both Tombo statistics files, then prints one
    tab-separated line (chrom, pos, strand, score) per entry of the 5mC
    file's iter_fracs(). Exits with status 1, both on a usage error and
    after printing, as the original stop point did.
    """
    # Four positional arguments are read below (argv[1]..argv[4]).
    if len(sys.argv) < 5:
        print('usage: python %s 5mC.tombo.per_read_stats 6mA.tombo.per_read_stats genome.fa outfile_prefix' % sys.argv[0])
        sys.exit(1)

    m5C = sys.argv[1]
    m6A = sys.argv[2]
    fasta = sys.argv[3]
    outprefix = sys.argv[4]

    genome = _load_genome(fasta)
    print('finished inputting genomic sequence')

    fm5C = tombo_stats.TomboStats(m5C)
    fm6A = tombo_stats.TomboStats(m6A)

    for (chrom, strand, pos, score, _) in fm5C.iter_fracs():
        print(chrom + '\t' + str(pos) + '\t' + strand + '\t' + str(score))

    # Exploration currently stops here. Later stages, to enable in order:
    #   _write_context_stats(fm5C, genome, outprefix)
    #   _print_blocks(fm6A, limit=1)
    #   _print_statistic_blocks(fm5C)
    #   _print_statistic_blocks(fm6A)
    sys.exit(1)

    
def print_hdf5_keys(filename):
    """Print each top-level key of the HDF5 file at *filename*, one per line."""
    with h5py.File(filename, 'r') as f:
        for key in f.keys():
            print(key)


# Run only when executed as a script, so importing this module has no side
# effects. The old top-level HDF5 dump referenced an undefined `filename`;
# it is now the print_hdf5_keys() helper above.
if __name__ == '__main__':
    run()
