##################################
#                                #
# Last modified 2018/10/04       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import itertools
import math
import os
import re
import string
import subprocess
import sys

# Nucleotide complement table (IUPAC codes, case-preserving). Ambiguity codes
# map to the code for the complementary base set: R (A/G) <-> Y (C/T),
# M (A/C) <-> K (G/T), B (not A) <-> V (not T), D (not C) <-> H (not G);
# S (C/G), W (A/T), N and X are their own complements.
_COMPLEMENT = {
    'A': 'T', 'T': 'A', 'G': 'C', 'C': 'G',
    'R': 'Y', 'Y': 'R', 'M': 'K', 'K': 'M',
    'S': 'S', 'W': 'W', 'B': 'V', 'V': 'B',
    'D': 'H', 'H': 'D', 'N': 'N', 'X': 'X',
}
# The comprehension is fully built before update() runs, so the dict is not
# mutated while it is being iterated.
_COMPLEMENT.update({k.lower(): v.lower() for k, v in _COMPLEMENT.items()})


def getReverseComplement(preliminarysequence):
    """Return the reverse complement of a nucleotide sequence.

    Letter case is preserved and IUPAC ambiguity codes are complemented
    (e.g. 'R' -> 'Y', 'm' -> 'k').

    Args:
        preliminarysequence: the sequence string to reverse-complement.

    Returns:
        The reverse-complemented sequence ('' for empty input).

    Raises:
        KeyError: if the sequence contains a character not in the table.
    """
    # A single join is linear; the previous repeated '+' was quadratic.
    return ''.join([_COMPLEMENT[base] for base in reversed(preliminarysequence)])


def _open_wig(path):
    """Open a wig/bedGraph file for line-by-line text reading.

    ``.bz2``, ``.gz`` and ``.zip`` files are streamed through ``bzip2 -cd``,
    ``gunzip -c`` and ``unzip -p`` respectively; anything else is opened
    directly. Returns ``(stream, proc)`` where ``proc`` is the decompressor
    process, or ``None`` for a plain file.
    """
    if path.endswith('.bz2'):
        cmd = ['bzip2', '-cd', path]
    elif path.endswith('.gz'):
        cmd = ['gunzip', '-c', path]
    elif path.endswith('.zip'):
        cmd = ['unzip', '-p', path]
    else:
        return open(path), None
    # An argument list (no shell) keeps file names containing spaces or shell
    # metacharacters from being interpreted by /bin/sh.
    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, universal_newlines=True)
    return proc.stdout, proc


def _read_genome(fasta):
    """Load a FASTA file into a ``{header: UPPERCASE_SEQUENCE}`` dict.

    The key is the header text after the leading '>' (up to any further
    '>'), and each header is printed as it is read. Lines before the first
    header are ignored.
    """
    genome = {}
    name = None
    chunks = []
    with open(fasta) as handle:
        for line in handle:
            if line.startswith('>'):
                if name is not None:
                    genome[name] = ''.join(chunks).upper()
                name = line.strip().split('>')[1]
                print(name)
                chunks = []
            elif name is not None:
                chunks.append(line.strip())
    if name is not None:
        genome[name] = ''.join(chunks).upper()
    return genome


def _count_genome_kmers(genome, K):
    """Count every A/C/G/T-only K-mer window on the forward strand of *genome*.

    Returns ``(counts, total)``: ``counts`` maps each of the 4**K possible
    K-mers to its number of occurrences, and ``total`` is the number of
    windows counted. Windows containing any other character are skipped.
    """
    counts = {''.join(p): 0 for p in itertools.product('ACGT', repeat=K)}
    total = 0
    positions = 0
    for name in genome:
        print(name)
        seq = genome[name]
        # len - K + 1 windows, so the last K-mer of each chromosome is included.
        for i in range(len(seq) - K + 1):
            positions += 1
            if positions % 10000000 == 0:
                print(str(positions // 1000000) + 'M positions processed')
            kmer = seq[i:i + K]
            if kmer in counts:
                counts[kmer] += 1
                total += 1
    return counts, total


def _accumulate_wig(path, genome, K, observed, reverse, label):
    """Add wig scores to the K-mer surrounding each covered position.

    Each data line is ``chrom[:...]<TAB>start<TAB>end<TAB>score``. Lines for
    chromosomes absent from *genome* are skipped, as are blank lines and
    lines starting with '#'. Every position ``i`` in ``[start, end)`` adds
    *score* to ``observed[kmer]``, where ``kmer`` is
    ``genome[chrom][i - K//2 : i - K//2 + K]`` (always exactly K long, and
    reverse-complemented when *reverse* is true). Windows that run off the
    chromosome or contain anything other than A/C/G/T are ignored.

    Returns the total score added to *observed*.

    Raises:
        IOError: if the decompressor exits with a non-zero status.
    """
    half = K // 2
    total = 0.0
    line_count = 0
    stream, proc = _open_wig(path)
    try:
        for raw in stream:
            line = raw.strip()
            # A blank line is skipped, not treated as end of file.
            if not line or line.startswith('#'):
                continue
            line_count += 1
            if line_count % 1000000 == 0:
                print('%d lines processed in %s strand file' % (line_count, label))
            fields = line.split('\t')
            seq = genome.get(fields[0].split(':')[0])
            if seq is None:
                continue
            left = int(fields[1])
            right = int(fields[2])
            score = float(fields[3])
            for i in range(left, right):
                start = i - half
                if start < 0 or start + K > len(seq):
                    continue
                kmer = seq[start:start + K]
                # observed holds exactly the A/C/G/T K-mers, a set closed under
                # reverse complement, so membership can be tested before
                # complementing; this also avoids KeyError on other letters.
                if kmer not in observed:
                    continue
                if reverse:
                    kmer = getReverseComplement(kmer)
                observed[kmer] += score
                total += score
    finally:
        stream.close()
        status = proc.wait() if proc is not None else 0
    if status:
        raise IOError('failed to decompress %s (exit status %d)' % (path, status))
    return total


def _ratio(numerator, denominator):
    """Return ``numerator / denominator`` as a float, or NaN if the denominator is 0."""
    if not denominator:
        return float('nan')
    return float(numerator) / denominator


def run():
    """Command-line entry point: K-mer enrichment around 5' end signal.

    Arguments (sys.argv): 5p.plus.wig 5p.minus.wig genome.fa k-mer_size outfile

    Writes a tab-separated table with one row per A/C/G/T K-mer, in sorted order:
    - ``observed``: the K-mer's share of the total plus- and minus-strand wig
      score. Minus-strand windows are reverse-complemented.
    - ``expected``: its share of all A/C/G/T K-mer windows on the genome's
      forward strand.
    - ``obs/expected``: their ratio, or 'nan' when undefined.
    """
    if len(sys.argv) < 6:
        print('usage: python %s 5p.plus.wig 5p.minus.wig genome.fa k-mer_size outfile' % sys.argv[0])
        print('\twig files can be a .gz, .bz2, and .zip file')
        sys.exit(1)

    pluswig = sys.argv[1]
    minuswig = sys.argv[2]
    fasta = sys.argv[3]
    K = int(sys.argv[4])
    outfilename = sys.argv[5]
    if K < 1:
        print('k-mer_size must be a positive integer')
        sys.exit(1)

    genome = _read_genome(fasta)
    genome_counts, total_kmers = _count_genome_kmers(genome, K)

    observed = dict.fromkeys(genome_counts, 0.0)
    total_score = _accumulate_wig(pluswig, genome, K, observed, False, 'plus')
    total_score += _accumulate_wig(minuswig, genome, K, observed, True, 'minus')
    print('finished inputting regions')

    with open(outfilename, 'w') as outfile:
        outfile.write('#kmer\tobserved\texpected\tobs/expected' + '\n')
        for kmer in sorted(observed):
            obs = _ratio(observed[kmer], total_score)
            exp = _ratio(genome_counts[kmer], total_kmers)
            outfile.write('\t'.join([kmer, str(obs), str(exp), str(_ratio(obs, exp))]) + '\n')
            
# Run only when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    run()
