#!/usr/bin/env python
# Copyright (c) 2008 The George Washington University
# Copyright (c) 2007 NHLBI, NIH
# Authors: Chongzhi Zang, Weiqun Peng, Dustin E Schones and Keji Zhao
#
# This software is distributable under the terms of the GNU General
# Public License (GPL) v2, the text of which can be found at
# http://www.gnu.org/copyleft/gpl.html. Installing, importing or
# otherwise using this module constitutes acceptance of the terms of
# this License.
#
# Disclaimer
# 
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# General Public License for more details.
#
# 
import re, os, sys, shutil
from math import *   
from string import *
from optparse import OptionParser
import operator

import BED
import GenomeData
import get_total_tag_counts
import Background_island_probscore_statistics

""" 
Take in coordinates from bed_graph-type summary files and find 'islands' of modifications.
There are a number of options here that can be turned on or off as needed.
Right now:

(1) scan for all 'islands', i.e. single windows or runs of consecutive
windows

(2) remove all single window 'islands' -- this can be commented out
when looking for more localized signals (such as TF binding sites?)

(3) try to combine islands that are within gap distance of each other.
This gap distance is supplemented by a window_buffer just so we don't
do anything stupid with window sizes

(4) Remove all single window combined islands -- if step (2) was done,
this is redundant

(5) Lastly, filter out all the islands we've found that have a total
score < islands_minimum_tags
"""

# Tolerance (bp) used when joining neighbouring windows into islands and
# added to the gap when merging nearby islands (see findIslands and
# combineProximalIslands).
window_buffer = 10

def fact(m):
    """
    Return m! as a float, computed by direct multiplication.

    m is expected to be a non-negative whole number (poisson() calls this
    with the count i for i < 20). The loop runs while m > 1 instead of
    until m == 1, so a non-integral m terminates rather than spinning
    forever (the result is then not a true factorial).

    Raises:
        ValueError: if m is negative.
    """
    if m < 0:
        raise ValueError("factorial of a negative number: %r" % (m,))
    value = 1.0
    while m > 1:
        value = value * m
        m = m - 1
    return value

def factln(m):
    """
    Return ln(m!).

    For m < 20 the factorial is multiplied out exactly and its log taken.
    For m >= 20 Srinivasa Ramanujan's approximation is used:

        ln(m!) ~= m ln(m) - m + ln(m (1 + 4m (1 + 2m))) / 6 + ln(pi) / 2

    The formula omits Ramanujan's small +1/30 term inside the sixth root, so
    its result differs from the exact value by roughly 1/(1440 m^3).

    Raises:
        ValueError: if m is negative (the exact branch would otherwise
        never terminate).
    """
    if m < 0:
        raise ValueError("log-factorial of a negative number: %r" % (m,))
    if m < 20:
        value = 1.0
        # 'm > 1' (not 'm != 1') so non-integral input still terminates.
        while m > 1:
            value = value * m
            m = m - 1
        return log(value)
    return m*log(m) - m + log(m*(1 + 4*m*(1 + 2*m)))/6.0 + log(pi)/2

def poisson(i, average):
    """
    Return the Poisson probability P(X = i) for mean `average`.

    For i < 20 the closed form e^(-average) * average^i / i! is evaluated
    directly with fact(). For larger i the probability is computed in log
    space with factln(), so large counts do not overflow average**i or i!.
    """
    if i < 20:
        return exp(-average) * average**i / fact(i)
    log_probability = -average + i*log(average) - factln(i)
    return exp(log_probability)
	

def find_threshold(pvalue, average):
    """
    Return the count threshold T of a Poisson(average) window for a p-value.

    T is the smallest integer found such that

        P(T) + P(T+1) + ... + P(infinity) <= pvalue,

    where P(k) = poisson(k, average). The tail is tracked as the running
    difference 1 - P(0) - ... - P(I), and T = I + 1.

    Raises:
        ValueError: if pvalue <= 0, which no finite threshold satisfies.
    """
    if pvalue <= 0:
        raise ValueError("pvalue must be positive, got %r" % (pvalue,))
    index = 0
    value = 1 - poisson(index, average)
    while value > pvalue:
        index += 1
        term = poisson(index, average)
        if index > average and value - term == value:
            # Beyond the mean the Poisson terms only shrink. Once subtracting
            # them no longer changes the running tail (rounding floor, or the
            # term underflowed to 0), the loop could never reach pvalue.
            break
        value -= term
    # index I is the value for which 1-P(0)-P(1)-...-P(I) <= pvalue
    return index + 1



def combineProximalIslands(islands, gap):
    """
    Merge islands found by findIslands that lie within gap + window_buffer bp.

    This is the same merge as combine_proximal_regions, with the module-level
    window_buffer as the extra tolerance. Consecutive islands whose distance
    |next.start - previous.end| is at most gap + window_buffer are merged; a
    merged island spans from the first start to the last end, and its value
    is the sum of the member values.

    Returns a list of islands; [] for empty input (previously an IndexError).
    """
    if not islands:
        return []
    return combine_proximal_regions(islands, gap, window_buffer)


def removeSingleWindowIslands(islands, window):
    """
    Return only the islands strictly wider than `window`.

    Width is measured as end - start. An island whose width is <= window is
    treated as a single-window island and dropped. Input order is preserved.
    """
    return [island for island in islands if island.end - island.start > window]



def findIslands(bed_vals, chrom, outfilename, window, gap, islands_minimum_tags):
    """
    Find the islands on one chromosome and write those with enough tags.

    The windows of `chrom` in `bed_vals` are joined into islands when their
    edges lie within window_buffer bp of each other. Islands within
    gap + window_buffer bp are then merged (both steps done by find_region).
    Islands no wider than one window are dropped. Every remaining island whose
    summed value is >= islands_minimum_tags is written to `outfilename` as a
    space-separated "chrom start end value" line.

    If `chrom` is not a key of bed_vals, nothing is done and no file is
    created. A chromosome with no windows produces an empty output file.
    Note that bed_vals[chrom] is sorted in place by start.
    """
    if chrom not in bed_vals.keys():
        return

    # find_region performs the consecutive-window scan and the proximal merge
    # with the same tolerance (window_buffer) this function always used.
    combined_islands = find_region(bed_vals, chrom, gap, window, window_buffer)
    combined_islands = removeSingleWindowIslands(combined_islands, window)

    outfile = open(outfilename, 'w')
    try:
        for i in combined_islands:
            if i.value >= islands_minimum_tags:
                outline = chrom + " " + str(i.start) + " " + str(i.end) + " " + str(i.value) + "\n"
                outfile.write(outline)
    finally:
        outfile.close()


def combine_proximal_regions(islands, gap, window_size_buffer=10):
    """
    Merge regions that lie close to one another into larger regions.

    `islands` is a list of regions of one chromosome (objects with chrom,
    start, end and value), expected in ascending start order. Consecutive
    regions whose distance |next.start - previous.end| is at most
    gap + window_size_buffer are merged. A merged region spans from the first
    start to the last end, and its value is the sum of the members' values.
    To forbid gaps pass gap = 0; to allow one empty window pass
    gap = window_size.

    Returns a new list of regions, or [] for empty input. A trailing region
    that is not merged into the run before it is returned as the original
    object; every other entry is a new BED.BED_GRAPH.
    """
    if not islands:
        # Nothing to merge; islands[-1] would otherwise raise IndexError.
        return []
    proximal_island_dist = gap + window_size_buffer
    last = len(islands) - 1
    combined_islands = []
    index = 0
    while index < last:
        first = islands[index]
        start, end, score = first.start, first.end, first.value
        # Absorb following regions for as long as they are close enough.
        while (index < last and
               abs(islands[index + 1].start - islands[index].end) <= proximal_island_dist):
            index += 1
            end = islands[index].end
            score += islands[index].value
        combined_islands.append(BED.BED_GRAPH(first.chrom, start, end, score))
        index += 1
    if index == last:
        # The final region was not absorbed by the run before it.
        combined_islands.append(islands[last])
    return combined_islands



def find_region(bed_vals, chrom, gap, window_size, window_size_buffer=10):
    """
    Build the regions (islands) of one chromosome from its windows.

    bed_vals[chrom] holds the bed_graph windows of the chromosome (objects
    with start, end and value). Any per-window score requirement is assumed
    to have been applied already when bed_vals was built.

    1. Windows are sorted by start, in place, so the caller's list is
       reordered.
    2. Windows whose gap |next.start - previous.end| is below
       window_size_buffer are joined into one region whose value is the sum
       of the window values.
    3. Regions within gap + window_size_buffer bp are merged with
       combine_proximal_regions.

    window_size is accepted for interface compatibility but currently unused
    (the single-window filter that would use it is disabled).

    Returns a list of BED.BED_GRAPH regions; [] when the chromosome has no
    windows. If chrom is not a key of bed_vals, a warning is printed and []
    is returned.
    """
    if chrom not in bed_vals.keys():
        print("Chromosome number not right!!")
        # Return an empty list (not None) so callers can iterate the result.
        return []

    tags = bed_vals[chrom]
    if len(tags) == 0:
        return []
    tags.sort(key=operator.attrgetter('start'))

    islands = []
    first = tags[0]
    run_start, run_end, run_value = first.start, first.end, first.value
    for previous, current in zip(tags, tags[1:]):
        if abs(current.start - previous.end) < window_size_buffer:
            # Adjacent window: extend the current run.
            run_end = current.end
            run_value += current.value
        else:
            islands.append(BED.BED_GRAPH(chrom, run_start, run_end, run_value))
            run_start, run_end, run_value = current.start, current.end, current.value
    # Close the run that ends at the last window.
    islands.append(BED.BED_GRAPH(chrom, run_start, run_end, run_value))

    return combine_proximal_regions(islands, gap, window_size_buffer)



def find_region_above_threshold(island_list, islands_minimum_tags):
    """
    Return the islands whose value reaches islands_minimum_tags.

    The cutoff is lowered by 1e-10, so values equal to the threshold up to
    floating-point rounding are kept. Input order is preserved.
    """
    cutoff = islands_minimum_tags - .0000000001
    return [island for island in island_list if island.value >= cutoff]


def find_region_above_threshold_from_file(Islands_file, species, islands_minimum_tags, out_islands_file):
    """
    Filter an island file by score and write the survivors.

    Islands_file is read with BED.BED(species, ..., "BED_GRAPH"). Islands
    whose value is >= islands_minimum_tags (with the tolerance of
    find_region_above_threshold) are written to out_islands_file as
    space-separated "chrom start end value" lines.

    Returns:
        (number_of_islands_written, summed_value_of_written_islands).
        Callers that ignore the return value are unaffected.
    """
    bed_vals = BED.BED(species, Islands_file, "BED_GRAPH")
    total_number_islands = 0
    total_tags_on_islands = 0.0
    outputfile = open(out_islands_file, 'w')
    try:
        for chrom in bed_vals.keys():
            islands = find_region_above_threshold(bed_vals[chrom], islands_minimum_tags)
            total_number_islands += len(islands)
            for i in islands:
                outline = chrom + " " + str(i.start) + " " + str(i.end) + " " \
                          + str(i.value) + "\n"
                outputfile.write(outline)
                total_tags_on_islands += i.value
    finally:
        outputfile.close()
    return total_number_islands, total_tags_on_islands



def main(argv):
	"""
	Probability scoring with random background model.
	
	"""
	parser = OptionParser()
	
	parser.add_option("-s", "--species", action="store", type="string", dest="species", help="mm8, hg18, background, etc", metavar="<str>")
	parser.add_option("-b", "--summarygraph", action="store",type="string", dest="summarygraph", help="summarygraph", metavar="<file>")
	parser.add_option("-w", "--window_size(bp)", action="store", type="int", dest="window_size", help="window_size(in bps)", metavar="<int>")
	parser.add_option("-g", "--gap_size(bp)", action="store", type="int",  dest="gap", help="gap size (in bps)", metavar="<int>")
	parser.add_option("-t", "--mappable_fraction_of_genome_size ", action="store", type="float", dest="fraction", help="mapable fraction of genome size", metavar="<float>")
	parser.add_option("-e", "--evalue ", action="store", type="float", dest="evalue", help="evalue that determines score threshold for significant islands", metavar="<float>")
	parser.add_option("-f", "--out_island_file", action="store", type="string", dest="out_island_file", help="output island file name", metavar="<file>")
	
	(opt, args) = parser.parse_args(argv)
	if len(argv) < 14:
        	parser.print_help()
        	sys.exit(1)

	if opt.species in GenomeData.species_chroms.keys():
		print "Species: ", opt.species;
		print "Window_size: ", opt.window_size;
		print "Gap size: ", opt.gap;
		print "E value is:", opt.evalue;
		
		total_read_count = get_total_tag_counts.get_total_tag_counts_bed_graph(opt.summarygraph);
		print "Total read count:", total_read_count
		genome_length = sum (GenomeData.species_chrom_lengths[opt.species].values());
		genome_length = int(opt.fraction * genome_length);

		average = float(total_read_count) * opt.window_size/genome_length; 
		print "Effective genome Length: ", genome_length;
		print "Window average:", average;
		
		window_pvalue = 0.20;
		bin_size = 0.001;
		print "Window pvalue:", window_pvalue;
		background = Background_island_probscore_statistics.Background_island_probscore_statistics(total_read_count, opt.window_size, opt.gap, window_pvalue, genome_length, bin_size);
		min_tags_in_window = background.min_tags_in_window
		print "Minimum num of tags in a qualified window: ", min_tags_in_window
		
		print "Generate the enriched probscore summary graph file"; 
		#read in the summary graph file
		bed_val = BED.BED(opt.species, opt.summarygraph, "BED_GRAPH");
		#generate the probscore summary graph file, only care about enrichment
		for chrom in bed_val.keys():
			if len(bed_val[chrom])>0:
				for index in xrange(len(bed_val[chrom])):
					read_count = bed_val[chrom][index].value;
					if ( read_count < min_tags_in_window):
						score = 0;
					else:
						prob = poisson(read_count, average);
						if prob <1e-250:
							score = 1000; #outside of the scale, take an arbitrary number.
						else:
							score = -log(prob);
					bed_val[chrom][index].value = score;
					#print chrom, start, read_count, score;
		
		#write the probscore summary graph file
		#Background_simulation_pr.output_bedgraph(bed_val, opt.out_sgraph_file);
		
		print "Filter the summary graph to get rid of ineligible windows ";
		#filter the summary graph to get rid of windows whose scores are less than window_score_threshold
		filtered_bed_val = {};
		for chrom in bed_val.keys():
			if len(bed_val[chrom])>0:
				filtered_bed_val [chrom]= [];
				for item in bed_val[chrom]:
					if item.value>0:
						filtered_bed_val[chrom].append(item);
		
		#Background_simulation_pr.output_bedgraph(filtered_bed_val, opt.out_sgraph_file+".filtered");
		
		print "Determine the score threshold from random background"; 
		#determine threshold from random background
		hist_outfile="L" + str(genome_length) + "_W" +str(opt.window_size) + "_G" +str(opt.gap) +  "_s" +str(min_tags_in_window) + "_T"+ str(total_read_count) + "_B" + str(bin_size) +"_calculatedprobscoreisland.hist";
		score_threshold = background.find_island_threshold(opt.evalue); 
		# background.output_distribution(hist_outfile);
		print "The score threshold is: ", score_threshold;
		
		
		print "Make and write islands";
		#make and write islands
		total_number_islands = 0;
		outputfile = open(opt.out_island_file, 'w');
		for chrom in filtered_bed_val.keys():
			if len(filtered_bed_val[chrom])>0:
				islands = find_region(filtered_bed_val, chrom, opt.gap, opt.window_size, 10);
				islands = find_region_above_threshold(islands, score_threshold);
				total_number_islands += len(islands);
			for i in islands:
				outline = chrom + "\t" + str(i.start) + "\t" + str(i.end) + "\t" + str(i.value) + "\n";	
				outputfile.write(outline);
		outputfile.close();	
		print "Total number of islands: ", total_number_islands;
		
	else:
		print "This species is not in my list!"; 

if __name__ == '__main__':
    # Hand the full argument vector (program name included) to main().
    main(sys.argv)
