/*
 * License: FreeBSD (Berkeley Software Distribution)
 * Copyright (c) 2013, Sara Sheehan, Kelley Harris, Yun Song
 */

package edu.berkeley.utility;

public class Discretization {

	/**
	 * One interval of a time discretization together with the probability weight assigned to it.
	 * All fields are final.
	 */
	public static class DiscretizationInterval {
		/** Probability weight assigned to this interval. */
		public final double weight;
		/** Left endpoint of the interval. */
		public final double startPoint;
		/**
		 * Right endpoint of the interval; {@code makeDiscretizationPoints} uses
		 * {@code Double.POSITIVE_INFINITY} for the final interval.
		 */
		public final double endPoint;

		public DiscretizationInterval(double weight, double startPoint, double endPoint) {
			this.weight = weight;
			this.startPoint = startPoint;
			this.endPoint = endPoint;
		}
	}

	/**
	 * Computes the expected number of remaining lineages at time {@code t}, given {@code n}
	 * lineages at time 0.
	 *
	 * <p>Evaluates {@code sum_{i=1..n} exp(-i(i-1)/2 * t) * descendingAscendingTable[i-1] * (2i-1)}.
	 *
	 * @param n number of lineages at time 0
	 * @param t (scaled) time at which to evaluate
	 * @param descendingAscendingTable per-term coefficients; must contain at least {@code n}
	 *        entries (presumably falling/rising factorial ratios of n — confirm with the caller)
	 * @return the expected number of remaining lineages
	 */
	public static double nbarTime(int n, double t, double[] descendingAscendingTable) {
		double sum = 0;
		for (int i = 1; i <= n; i++) {
			// Compute i(i-1)/2 in double arithmetic: the int product i*(i-1) overflows for
			// i > 46341, which would flip the exponent's sign (exp -> Infinity, then NaN when
			// multiplied by a zero coefficient). For smaller i this is exactly the old value.
			double pairs = i * (i - 1.0) / 2.0;
			sum += Math.exp(-pairs * t) * descendingAscendingTable[i - 1] * (2 * i - 1);
		}
		return sum;
	}

	/**
	 * Computes the expected number of remaining lineages at {@code times[k]}, where time is
	 * rescaled by the piecewise-constant population sizes:
	 * {@code t = sum_{j<k} (times[j+1] - times[j]) / sizes[j]}.
	 *
	 * @param times interval start points (same length as {@code sizes})
	 * @param sizes population size for each interval
	 * @param n number of lineages at time 0
	 * @param k index of the time point; must satisfy {@code k < times.length}
	 * @param descendingAscendingTable coefficients forwarded to {@link #nbarTime}
	 * @return the expected number of remaining lineages at {@code times[k]}
	 */
	public static double computeNbar(double[] times, double[] sizes, int n, int k, double[] descendingAscendingTable) {
		assert sizes.length == times.length;
		double t = 0;
		for (int j = 0; j < k; j++) {
			t += (times[j + 1] - times[j]) / sizes[j];
		}
		return nbarTime(n, t, descendingAscendingTable);
	}

	/**
	 * Computes {@link #computeNbar} for every index {@code k = 0 .. sizes.length-1}.
	 *
	 * <p>The rescaled time is accumulated incrementally (O(d) instead of O(d^2)). The additions
	 * happen in the same order as in {@code computeNbar}, so each entry is bit-identical to
	 * {@code computeNbar(times, sizes, n, k, descendingAscendingTable)}.
	 *
	 * @return array whose k-th entry is the expected number of lineages at {@code times[k]}
	 */
	public static double[] computeNbarArray(double[] times, double[] sizes, int n, double[] descendingAscendingTable) {
		assert sizes.length == times.length;
		int d = sizes.length;
		double[] nbarArray = new double[d];
		double t = 0;
		for (int k = 0; k < d; k++) {
			nbarArray[k] = nbarTime(n, t, descendingAscendingTable);
			if (k + 1 < d) {
				t += (times[k + 1] - times[k]) / sizes[k];
			}
		}
		return nbarArray;
	}

	/**
	 * Computes the weights for a pre-specified discretization, using the expected number of
	 * lineages ({@code nbarArray}) in each interval.
	 *
	 * <p>Interval {@code i} spans {@code [times[i], times[i+1])}. The last interval spans
	 * {@code [times[d-1], +Infinity)}. Its weight is {@code rescaleMargin * (1 - exp(-length * nbar / size))},
	 * where {@code rescaleMargin} is the mass not yet assigned to earlier intervals. The
	 * unbounded final interval absorbs all remaining mass, so the weights sum to 1.
	 *
	 * @param times interval start points (same length as {@code sizes})
	 * @param sizes population size for each interval
	 * @param n number of lineages; currently unused, kept for interface compatibility
	 * @param nbarArray expected number of lineages per interval (at least {@code d} entries)
	 * @return one {@link DiscretizationInterval} per entry of {@code times}
	 */
	public static DiscretizationInterval[] makeDiscretizationPoints(double[] times, double[] sizes, int n, double[] nbarArray) {
		assert sizes.length == times.length;
		int d = sizes.length;

		DiscretizationInterval[] pointsNew = new DiscretizationInterval[d];
		double weightSum = 0;
		// probability mass not yet assigned to an earlier interval
		double rescaleMargin = 1;

		for (int i = 0; i < d; i++) {
			boolean last = (i == d - 1);
			double start = times[i];
			double end = last ? Double.POSITIVE_INFINITY : times[i + 1];
			// Probability of "surviving" the whole interval. For the unbounded final interval it
			// is 0, so that interval takes all remaining mass. Setting it directly avoids
			// Infinity * 0 = NaN when nbarArray[d-1] is 0. It matches exp(-Infinity) = 0 otherwise.
			double mainWeight = last ? 0.0 : Math.exp(-(end - start) * nbarArray[i] / sizes[i]);
			double newWeight = rescaleMargin * (1 - mainWeight);
			pointsNew[i] = new DiscretizationInterval(newWeight, start, end);
			weightSum += newWeight;
			rescaleMargin *= mainWeight;
		}

		// check to make sure the weights add up to 1 (vacuous when d == 0)
		assert d == 0 || Math.abs(weightSum - 1) < 1E-10;
		return pointsNew;
	}
}