##################################
#                                #
# Last modified 2020/07/26       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import os
import gzip
import string
import math
import numpy as np
from sets import Set

def gini(array):
    """Calculate the Gini coefficient of a collection of numbers.

    Based on the bottom equation of
    http://www.statsdirect.com/help/content/image/stat0206_wmf.gif
    from http://www.statsdirect.com/help/default.htm#nonparametric_methods/gini.htm

    Args:
        array: numpy array (or array-like) of numbers of any shape. It is
            flattened, so every value is treated equally. The caller's data
            is never modified.

    Returns:
        The Gini coefficient as a float (about 0 for perfectly even values,
        approaching 1 for maximally uneven values).

    Raises:
        ValueError: if ``array`` is empty (raised by ``np.amin``).
    """
    # Work on a flat *float* copy. The in-place arithmetic below would fail
    # (or silently truncate the small offset) on an integer-typed array.
    values = np.asarray(array, dtype=float).flatten()
    lowest = np.amin(values)
    if lowest < 0:
        values -= lowest  # values cannot be negative
    values += 0.0000001  # values cannot be 0
    values = np.sort(values)  # values must be sorted
    n = values.shape[0]  # number of array elements
    index = np.arange(1, n + 1)  # 1-based rank of each sorted element
    return np.sum((2 * index - n - 1) * values) / (n * np.sum(values))

def getReverseComplement(preliminarysequence):
    """Return the reverse complement of a DNA sequence.

    Args:
        preliminarysequence: string over the letters A, C, G, T and N in
            either case. The case of each letter is preserved.

    Returns:
        The complemented sequence, read in the opposite direction.

    Raises:
        KeyError: if the sequence contains any other character.
    """
    DNA = {'A':'T','T':'A','G':'C','C':'G','N':'N','a':'t','t':'a','g':'c','c':'g','n':'n'}
    # A single join avoids the quadratic cost of growing a string one
    # character at a time.
    return ''.join(DNA[base] for base in reversed(preliminarysequence))

def _read_fasta(path):
    """Read a FASTA file (gzip-compressed if it ends in '.gz') into a dict.

    Returns a dict mapping each header (the text after '>') to its full
    sequence with line breaks removed. A repeated header overwrites the
    earlier record. Raises ValueError if sequence text precedes the first
    header.
    """
    opener = gzip.open if path.endswith('.gz') else open
    genome = {}
    name = None
    chunks = []
    with opener(path) as handle:
        for line in handle:
            if not isinstance(line, str):
                # gzip yields bytes on Python 3; no-op on Python 2.
                line = line.decode()
            if line.startswith('>'):
                if name is not None:
                    genome[name] = ''.join(chunks)
                name = line.strip().split('>')[1]
                chunks = []
            elif name is not None:
                chunks.append(line.strip())
            elif line.strip():
                raise ValueError(
                    'sequence data found before the first FASTA header in %s' % path)
    if name is not None:
        genome[name] = ''.join(chunks)
    return genome


def _count_kmers(genome, k, min_letters):
    """Count every length-k substring of every sequence in ``genome``.

    k-mers with fewer than ``min_letters`` distinct characters are skipped.
    Counts are stored as floats so the report keeps its '3.0'-style format.
    """
    counts = {}
    for sequence in genome.values():
        for start in range(len(sequence) - k + 1):
            kmer = sequence[start:start + k]
            if min_letters > 1 and len(set(kmer)) < min_letters:
                continue
            counts[kmer] = counts.get(kmer, 0.0) + 1
    return counts


def run():
    """Command-line entry point: count k-mers in a FASTA file and report them.

    Reads from ``sys.argv``: the FASTA path, the k-mer length K, the output
    file name, and the optional flags ``-noHomoPolymers`` (drop k-mers made
    of one letter) and ``-noSimpleRepeats`` (drop k-mers made of one or two
    letters). Only the forward strand is counted.

    The output is tab-separated: a header line, one line per k-mer (k-mer,
    count, probability) sorted by decreasing count, then an 'entropy' line
    (entropy normalised by log2 of the number of distinct k-mers) and a
    'GINI' line (Gini coefficient of the counts). With a single distinct
    k-mer the entropy is reported as 0.0; with no k-mers the Gini is 'nan'.

    Exits with status 1 when arguments are missing or K is not positive.
    """
    # Three positional arguments are required; sys.argv[3] is read below.
    if len(sys.argv) < 4:
        print('usage: python %s fasta K outputfilename [-noHomoPolymers] [-noSimpleRepeats]' % sys.argv[0])
        print('\tNote: the script will not count kmers on the reverse complements of the input sequences!')
        print('\tthe [-noSimpleRepeats] option will remove all kmers containing only two letters')
        sys.exit(1)

    fasta = sys.argv[1]
    K = int(sys.argv[2])
    outfilename = sys.argv[3]

    if K < 1:
        print('K must be a positive integer')
        sys.exit(1)

    # Minimum number of distinct letters a k-mer needs in order to be kept.
    min_letters = 1
    if '-noHomoPolymers' in sys.argv:
        min_letters = 2
        print('will omit k-mers with only one letter')
    if '-noSimpleRepeats' in sys.argv:
        min_letters = 3
        print('will omit k-mers with only one or two letters')

    KmerDict = _count_kmers(_read_fasta(fasta), K, min_letters)
    Total = sum(KmerDict.values(), 0.0)

    # (count, kmer) pairs are unique, so a reversed sort gives exactly the
    # order of sort-then-reverse.
    FinalList = sorted(((count, kmer) for kmer, count in KmerDict.items()),
                       reverse=True)

    # log2(1) == 0 would make the normalisation divide by zero, so a single
    # distinct k-mer keeps an entropy of 0.0.
    distinct = len(FinalList)
    norm = math.log(distinct, 2) if distinct > 1 else 0.0

    # Compute the Gini coefficient before opening the output file so that a
    # failure does not leave a partial report behind.
    if FinalList:
        GINI = gini(np.array([count for (count, kmer) in FinalList]))
    else:
        GINI = float('nan')

    H = 0.0
    with open(outfilename, 'w') as outfile:
        outfile.write('#kmer\tCounts\tprobability' + '\n')
        for (C, kmer) in FinalList:
            p = C / Total
            outfile.write(kmer + '\t' + str(C) + '\t' + str(p) + '\n')
            if norm:
                H += (-p * math.log(p, 2)) / norm
        outfile.write('entropy' + '\t' + str(Total) + '\t' + str(H) + '\n')
        outfile.write('GINI' + '\t' + str(Total) + '\t' + str(GINI) + '\n')
   
# Run the command-line tool only when this file is executed as a script, so
# that importing the module does not parse sys.argv or exit the interpreter.
if __name__ == '__main__':
    run()
