#!/usr/bin/python

"""Detect differential nucleosome positioning regions (DiNuP) between two samples.

Copyright (c) 2011 Kai Fu  < kelvinfu.tju@gmail.com >

This code is free software; you can redistribute it and/or modify it
under the terms of the Artistic License (see the file COPYING included
with the distribution).

@status:  alpha
@version: $Released$
@author:  Kai Fu
@contact: kelvinfu.tju@gmail.com
"""


# -----------------------------------
# python modules
# -----------------------------------

import os
import logging
import string
import math
import random
from math import sqrt
from array import array

# -----------------------------------
# own python modules
# -----------------------------------
from DINUP.ks_test import ks_2samp


class Peak_Detect:

    """ Call differential nucleosome positioning regions between two samples.

    The method inputs/outputs suggest this order of use (the driver is not part
    of this class): read_tfile/read_cfile -> get_chr_length -> prepare_list ->
    sliding_window -> get_cutoff -> call_peaks or call_peaks_fold ->
    get_peak_region or get_peak_region_fold -> to_wig/to_bed/to_xls*.
    """

    # Nucleosome length (bp) used in genomic_coverage's coverage formula.
    _NUCLEOSOME_BP = 146
    # Score used when a p-value cannot be log-transformed (p <= 0) or the KS
    # test fails. -10*log10 of the smallest positive double (~4.9e-324) is
    # ~3233, so this caps above any score a representable p-value produces.
    _MAX_SCORE = 3250
    # Leading columns shared by the tab-separated .xls reports.
    _XLS_COLUMNS = ['chromosome', 'start', 'end', 'effective width',
                    '-10*log10(pvalue)', 'summit']

    def __init__(self, treat_file=None, control_file=None, name=None, window=None,
                 fdr=None, format=None, bias=None, times=None, region_length=None,
                 asize=None, bsize=None):
        """ Initialize the Peak_Detect object.

        Parameters:
        1. treat_file: path of the first (treatment) BED file
        2. control_file: path of the second (control) BED file
        3. name: prefix of the output files written under DiNuP_results/
        4. window: size of the sliding window (bp)
        5. fdr: quantile of the background p-values taken as cutoff (get_cutoff)
        6. format: format of the input files; this version only supports BED
        7. bias: maximal random coordinate shift (bp) applied in call_peaks
        8. times: number of randomly shifted repeats per window in call_peaks
        9. region_length: minimum length (bp) of a reported region
        10. asize: size used to shift treatment read ends to centers (+/- asize/2)
        11. bsize: same as asize, for the control reads
        """
        self.treat_file = treat_file
        self.control_file = control_file
        self.name = name
        self.window = window
        self.fdr = fdr
        self.format = format
        self.bias = bias
        self.times = times
        self.region_length = region_length
        self.step = 10      # the step of the sliding window
        self.asize = asize
        self.bsize = bsize

    # ----------------------------------------------------------------- input

    def _read_bed(self, path, read_size, label):
        """ Read a BED file and collect strand-shifted read centers per chromosome.

        Plus-strand reads are centered at start + read_size // 2 and
        minus-strand reads at end - read_size // 2. Blank lines and
        'track' / 'browser' / '#' header lines are skipped, as are lines whose
        strand (6th column) is neither '+' nor '-'.

        Returns (centers, count): chromosome -> list of centers in file order,
        and the number of reads collected. Exits with status 1 when the file
        cannot be opened. On the first malformed line it logs an error and
        returns what was read so far.
        """
        try:
            bed_file = open(path, 'r')
        except (IOError, OSError, TypeError):
            logging.error("No such %s file!", label)
            raise SystemExit(1)

        centers = {}
        count = 0
        half_size = read_size // 2
        with bed_file:
            for line in bed_file:
                fields = line.strip().split('\t')
                if not fields[0] or fields[0].startswith(('#', 'track', 'browser')):
                    continue
                try:
                    strand = fields[5]
                    if strand == '+':
                        center = int(fields[1]) + half_size
                    elif strand == '-':
                        center = int(fields[2]) - half_size
                    else:
                        continue    # unknown strand: the center cannot be placed
                except (IndexError, ValueError):
                    logging.error("Please check whether the %s file is BED file!", label)
                    break
                centers.setdefault(fields[0], []).append(center)
                count += 1
        return centers, count

    def read_tfile(self):
        """ Read the treatment file.

        Returns (treat, i): chromosome -> list of read centers, and the number
        of reads read.
        """
        return self._read_bed(self.treat_file, self.asize, "treatment")

    def read_cfile(self):
        """ Read the control file.

        Returns (control, j): chromosome -> list of read centers, and the
        number of reads read.
        """
        return self._read_bed(self.control_file, self.bsize, "control")

    # ------------------------------------------------------- genome / layout

    def get_chr_length(self, treat, control):
        """ Sort the read centers in place and derive per-chromosome lengths.

        Only chromosomes present in both samples are kept; each length is the
        largest read center of either sample plus 100 bp.
        """
        for centers in treat.values():
            centers.sort()
        for centers in control.values():
            centers.sort()

        chr_length = {}
        for key in set(treat) & set(control):
            chr_length[key] = max(treat[key][-1], control[key][-1]) + 100
        return chr_length

    @staticmethod
    def _round_half_up(numerator, denominator):
        """ Return numerator / denominator rounded to the nearest int, halves up. """
        ratio = float(numerator) / denominator
        whole = int(ratio)
        if ratio - whole >= 0.5:
            return whole + 1
        return whole

    def genomic_coverage(self, i, j, chr_length):
        """ Calculate the genomic coverage of both samples.

        coverage = 146 * total tag number / effective genome size, rounded
        half up, where the effective genome size is the sum of chr_length.
        Raises ZeroDivisionError when chr_length is empty.
        """
        gsize = sum(chr_length.values())
        coverage_treat = self._round_half_up(self._NUCLEOSOME_BP * i, gsize)
        coverage_control = self._round_half_up(self._NUCLEOSOME_BP * j, gsize)
        return coverage_treat, coverage_control

    def prepare_list(self, chr_length):
        """ Allocate one empty array('i') per step-sized window of each chromosome.

        Returns (distance_treat, distance_control), both chromosome -> list of
        independent arrays.
        """
        distance_treat = {}
        distance_control = {}
        for key, length in chr_length.items():
            distance_treat[key] = [array('i', []) for _ in range(0, length, self.step)]
            distance_control[key] = [array('i', []) for _ in range(0, length, self.step)]
        return distance_treat, distance_control

    def sliding_window(self, treat, control, distance_treat, distance_control):
        """ Record, for every window covering a read, the read-to-window-center distance.

        Window j spans [j*step, j*step + window) and is centered at
        j*step + window // 2. Windows that would start before position 0 are
        skipped (a read near the chromosome start only feeds windows j >= 0).
        Windows left without any read are replaced by the placeholder [0].
        The distance containers are filled in place and also returned.
        """
        step = self.step
        half_window = self.window // 2
        windows_per_read = self.window // step

        for centers, distances in ((treat, distance_treat), (control, distance_control)):
            try:
                for key, windows in distances.items():
                    for pos in centers[key]:
                        first = max(0, pos // step - windows_per_read + 1)
                        last = pos // step
                        for j in range(first, last + 1):
                            windows[j].append(pos - (j * step + half_window))
            except MemoryError:
                # if so, please use a computer with a larger memory
                logging.error("Memory error when using sliding window!")

        for distances in (distance_treat, distance_control):
            for windows in distances.values():
                for idx, window in enumerate(windows):
                    if len(window) == 0:
                        windows[idx] = [0]
        return distance_treat, distance_control

    # --------------------------------------------------------------- scoring

    @staticmethod
    def _is_empty_window(window):
        """ True if a window holds no distance or only the [0] placeholder.

        A single read exactly at the window center is indistinguishable from
        the placeholder.
        """
        return len(window) == 0 or (len(window) == 1 and window[0] == 0)

    @staticmethod
    def _score(pvalue):
        """ Return -10*log10(pvalue), or _MAX_SCORE when the log is undefined. """
        try:
            return -10 * math.log10(pvalue)
        except ValueError:
            return Peak_Detect._MAX_SCORE

    def _ks_score(self, treat_window, control_window):
        """ Return the -10*log10 KS-test p-value of two windows (_MAX_SCORE on failure).

        Assumes ks_2samp returns a pair whose second item is the p-value
        (scipy-style); verify against DINUP.ks_test.
        """
        try:
            pvalue = ks_2samp(sorted(treat_window), sorted(control_window))[1]
        except Exception:
            return self._MAX_SCORE
        return self._score(pvalue)

    def _jittered_score(self, treat_window, control_window, observed):
        """ Average _ks_score over self.times randomly shifted copies of two windows.

        Every distance of a copy is shifted by an independent random integer in
        [-bias, bias]. The input windows are never modified. When self.times is
        0 or None, the unshifted `observed` score is returned.
        """
        if not self.times:
            return observed
        scores = []
        for _ in range(self.times):
            shifted_treat = [d + random.randint(-self.bias, self.bias) for d in treat_window]
            shifted_control = [d + random.randint(-self.bias, self.bias) for d in control_window]
            scores.append(self._ks_score(shifted_treat, shifted_control))
        return sum(scores) / self.times

    def get_cutoff(self, distance_treat, distance_control, chr_length):
        """ Estimate the score cutoff from a permutation background of KS p-values.

        Samples gsize // 100 + 1 random windows (gsize = sum of chr_length) that
        hold reads in at least one sample. For each, the two samples' distances
        are pooled, shuffled and re-split into groups of the original sizes, and
        the KS-test p-value of that split is recorded. The p-value at index
        int(times * fdr) of the sorted background is returned as -10*log10(p),
        rounded to an int when p < 0.0001.

        If every window is empty in both samples, the sampling loop never ends.
        """
        times = sum(chr_length.values()) // 100
        keys = list(distance_treat.keys())
        background = []

        while len(background) <= times:
            key = keys[random.randint(0, len(keys) - 1)]
            index = random.randint(0, chr_length[key] // self.step - 1)
            treat_window = list(distance_treat[key][index])
            control_window = list(distance_control[key][index])
            if self._is_empty_window(treat_window) and self._is_empty_window(control_window):
                continue
            pooled = treat_window + control_window
            random.shuffle(pooled)      # random relabelling of the pooled distances
            n_treat = len(treat_window)
            background.append(ks_2samp(sorted(pooled[:n_treat]), sorted(pooled[n_treat:]))[1])

        background.sort()
        pvalue = background[int(times * self.fdr)]
        score = self._score(pvalue)
        if pvalue >= 0.0001:
            return score
        return int("%.0f" % score)

    def call_peaks(self, chr_length, distance_treat, distance_control, cutoff):
        """ Call differential nucleosome positioning windows with the KS test.

        A window whose -10*log10 KS p-value exceeds -10*log10(cutoff) (0 when
        cutoff is 0) is re-scored as the average over `times` randomly shifted
        copies (see _jittered_score); other windows score 0.

        NOTE(review): `cutoff` is treated as a raw p-value here, whereas
        get_cutoff returns an already log-scaled score; confirm what the
        caller passes.

        Returns (pvalue_adjust, location): chromosome -> per-window score, and
        chromosome -> window center (step*i + window // 2), chromosomes sorted.
        """
        pvalue_adjust = {}
        location = {}
        pvalue_cutoff_log = 0 if cutoff == 0 else -10 * math.log10(cutoff)
        half_window = self.window // 2

        for key in sorted(chr_length):
            logging.info("calculating %s", key)
            treat_windows = distance_treat[key]
            control_windows = distance_control[key]
            n_windows = min(len(treat_windows), len(control_windows))
            scores = []
            positions = []
            # enumerate + break keeps iteration lazy on Python 2 for long chromosomes
            for i, treat_window in enumerate(treat_windows):
                if i >= n_windows:
                    break
                control_window = control_windows[i]
                score = self._ks_score(treat_window, control_window)
                if score > pvalue_cutoff_log:
                    score = self._jittered_score(treat_window, control_window, score)
                else:
                    score = 0
                scores.append(score)
                positions.append(self.step * i + half_window)
            if scores:
                pvalue_adjust[key] = scores
                location[key] = positions
        return pvalue_adjust, location

    def call_peaks_fold(self, chr_length, distance_treat, distance_control,
                        fold_cutoff, coverage_treat, coverage_control):
        """ Call differential nucleosome positioning windows by fold change.

        fold change = (treat reads / control reads) / (coverage_treat /
        coverage_control). It is kept when above fold_cutoff or below
        1/fold_cutoff, and set to 0 otherwise or when either window holds at
        most one entry. Returns (pvalue_adjust, location) shaped as in
        call_peaks. A zero coverage raises ZeroDivisionError.
        """
        pvalue_adjust = {}
        location = {}
        half_window = self.window // 2
        coverage_ratio = coverage_treat * 1.0 / coverage_control
        lower_cutoff = 1.0 / fold_cutoff

        for key in sorted(chr_length):
            logging.info("calculating %s", key)
            treat_windows = distance_treat[key]
            control_windows = distance_control[key]
            n_windows = min(len(treat_windows), len(control_windows))
            folds = []
            positions = []
            for i, treat_window in enumerate(treat_windows):
                if i >= n_windows:
                    break
                control_window = control_windows[i]
                fold_change = 0
                if len(treat_window) > 1 and len(control_window) > 1:
                    ratio = (len(treat_window) * 1.0 / len(control_window)) / coverage_ratio
                    if ratio > fold_cutoff or ratio < lower_cutoff:
                        fold_change = ratio
                folds.append(fold_change)
                positions.append(self.step * i + half_window)
            if folds:
                pvalue_adjust[key] = folds
                location[key] = positions
        return pvalue_adjust, location

    # --------------------------------------------------------------- regions

    def _collect_regions(self, values_by_chr, is_candidate):
        """ Turn per-window scores into merged candidate regions.

        Per chromosome, a run of consecutive windows whose score satisfies
        `is_candidate` becomes a region when a non-candidate window closes it
        and the run spans at least max(1, region_length // step) windows; a run
        still open at the last window is not reported. With i the closing
        window's index and n the run length, the region covers
        [step*(i-n), step*i + window); its summit is the center of the
        highest-scoring window (first one on ties). A region starting at or
        before the previous region's end is merged into it, keeping the higher
        peak. Runs never span two chromosomes.

        Returns five dicts keyed by chromosome (only chromosomes with regions):
        starts, ends, widths (end - start), peak scores, summits.
        """
        starts, ends, widths, peaks, summits = {}, {}, {}, {}, {}
        min_windows = max(1, self.region_length // self.step)
        half_window = self.window // 2

        for key in sorted(values_by_chr):
            regions = []    # [start, end, peak, summit] per region of this chromosome
            run = []        # scores of the current run of candidate windows
            for i, value in enumerate(values_by_chr[key]):
                if is_candidate(value):
                    run.append(value)
                    continue
                if len(run) >= min_windows:
                    n = len(run)
                    peak_index = max(range(n), key=run.__getitem__)
                    start = self.step * (i - n)
                    end = self.window + self.step * i
                    peak = run[peak_index]
                    summit = start + half_window + peak_index * self.step
                    if regions and start <= regions[-1][1]:
                        previous = regions[-1]      # merge adjacent peak regions
                        previous[1] = end
                        if peak > previous[2]:
                            previous[2] = peak
                            previous[3] = summit
                    else:
                        regions.append([start, end, peak, summit])
                run = []
            if regions:
                starts[key] = [r[0] for r in regions]
                ends[key] = [r[1] for r in regions]
                widths[key] = [r[1] - r[0] for r in regions]
                peaks[key] = [r[2] for r in regions]
                summits[key] = [r[3] for r in regions]
        return starts, ends, widths, peaks, summits

    def get_peak_region(self, pvalue_adjust, cutoff):
        """ Smooth and merge candidate windows into differential positioning regions.

        Each interior window's score is replaced by the mean of its two
        neighbours (first and last windows are kept). Windows scoring at least
        -10*log10(cutoff) (0 when cutoff is 0) are candidates; see
        _collect_regions for how runs become regions and for the return value.
        """
        if cutoff != 0:
            pvalue_cutoff_log = -10 * math.log10(cutoff)
        else:
            pvalue_cutoff_log = 0

        smoothed = {}
        for key, values in pvalue_adjust.items():
            last = len(values) - 1
            smoothed[key] = [(values[i - 1] + values[i + 1]) / 2.0 if 0 < i < last else value
                             for i, value in enumerate(values)]
        return self._collect_regions(smoothed, lambda value: value >= pvalue_cutoff_log)

    def get_peak_region_fold(self, pvalue_adjust, cutoff):
        """ Merge non-zero fold-change windows into differential positioning regions.

        Every non-zero window is a candidate (presumably the output of
        call_peaks_fold, which zeroes windows failing its fold cutoff);
        `cutoff` is accepted for interface compatibility but not used. See
        _collect_regions for the region rules and the return value.
        """
        return self._collect_regions(pvalue_adjust, lambda value: value != 0)

    # ---------------------------------------------------------------- output

    def to_wig(self, pvalue_adjust, location):
        """ Write the per-window scores to DiNuP_results/<name>_dnpr.wig.

        One fixedStep track per chromosome (sorted), starting at window // 2
        with the sliding-window step. `location` is accepted for interface
        compatibility but not used.
        """
        wig_name = os.path.join('DiNuP_results', self.name + '_dnpr.wig')
        with open(wig_name, 'w') as wig_file:
            for key in sorted(pvalue_adjust):
                wig_file.write("track type=wiggle_0 " + "name=" + str(key) + '\n')   # data type line
                wig_file.write("fixedStep  chrom=%s  " % (key) + "  start=" + str(self.window // 2)
                               + "  step=" + str(self.step) + '\n')
                for value in pvalue_adjust[key]:
                    wig_file.write(str(value) + '\n')

    def _write_bed(self, suffix, dynamic_region_start, dynamic_region_end, dynamic_region, pvalue_largest):
        """ Write chromosome, start, end, width and rounded peak score per region to DiNuP_results/<name><suffix>. """
        output_file_name = os.path.join('DiNuP_results', self.name + suffix)
        with open(output_file_name, 'w') as output_file:
            for key in sorted(dynamic_region_start):
                rows = zip(dynamic_region_start[key], dynamic_region_end[key],
                           dynamic_region[key], pvalue_largest[key])
                for start, end, width, peak in rows:
                    output_file.write('\t'.join([str(key), str(start), str(end), str(width),
                                                 str(round(peak, 2))]) + '\n')

    def to_bed(self, dynamic_region_start, dynamic_region_end, dynamic_region, pvalue_largest):
        """ Write the KS-based regions to DiNuP_results/<name>_dinup.bed. """
        self._write_bed('_dinup.bed', dynamic_region_start, dynamic_region_end,
                        dynamic_region, pvalue_largest)

    def to_bed_fold(self, dynamic_region_start, dynamic_region_end, dynamic_region, pvalue_largest):
        """ Write the fold-change regions to DiNuP_results/<name>_fold.bed. """
        self._write_bed('_fold.bed', dynamic_region_start, dynamic_region_end,
                        dynamic_region, pvalue_largest)

    def _xls_fields(self, key, i, dynamic_region_start, dynamic_region_end, dynamic_region,
                    pvalue_largest, pvalue_location):
        """ Return the leading .xls columns of region i on chromosome key as strings. """
        return [str(key), str(dynamic_region_start[key][i]), str(dynamic_region_end[key][i]),
                str(dynamic_region[key][i]), str(round(pvalue_largest[key][i], 2)),
                str(pvalue_location[key][i])]

    def to_xls(self, dynamic_region_start, dynamic_region_end, dynamic_region, pvalue_largest,
               pvalue_location, argtxt):
        """ Write coordinates, effective width, peak score and summit of each region
        to DiNuP_results/<name>_dinup.xls, preceded by argtxt.
        """
        output_file_name = os.path.join('DiNuP_results', self.name + '_dinup.xls')
        with open(output_file_name, 'w') as output_file:
            output_file.write('\n' + argtxt)
            output_file.write('\n')
            output_file.write('\t'.join(self._XLS_COLUMNS) + '\n')
            for key in sorted(dynamic_region_start):
                for i in range(len(dynamic_region_start[key])):
                    fields = self._xls_fields(key, i, dynamic_region_start, dynamic_region_end,
                                              dynamic_region, pvalue_largest, pvalue_location)
                    output_file.write('\t'.join(fields) + '\t' + '\n')   # rows keep a trailing tab

    def to_xls_with_feature(self, dynamic_region_start, dynamic_region_end, dynamic_region,
                            pvalue_largest, pvalue_location, fold_change,
                            positioning_degree_variation, positioning_variation, argtxt):
        """ Write the to_xls columns plus repositioning, occupancy and positioning-degree
        changes per region to DiNuP_results/<name>_dinup.xls, preceded by argtxt.
        """
        output_file_name = os.path.join('DiNuP_results', self.name + '_dinup.xls')
        header = self._XLS_COLUMNS + ['repositioned variation', 'occupancy change',
                                      'positioning degree change']
        with open(output_file_name, 'w') as output_file:
            output_file.write('\n' + argtxt)
            output_file.write('\n')
            output_file.write('\t'.join(header) + '\n')
            for key in sorted(dynamic_region_start):
                for i in range(len(dynamic_region_start[key])):
                    fields = self._xls_fields(key, i, dynamic_region_start, dynamic_region_end,
                                              dynamic_region, pvalue_largest, pvalue_location)
                    fields += [str(positioning_variation[key][i]),
                               str(round(fold_change[key][i], 2)),
                               str(round(positioning_degree_variation[key][i], 2))]
                    output_file.write('\t'.join(fields) + '\n')

