/*
 * License: FreeBSD (Berkeley Software Distribution)
 * Copyright (c) 2013, Sara Sheehan, Kelley Harris, Yun Song
 */

package edu.berkeley.smcsd;

import edu.berkeley.utility.Discretization;
import edu.berkeley.utility.LogSum;
import edu.berkeley.utility.ParamSet;
import edu.berkeley.utility.Discretization.DiscretizationInterval;

// This class receives the updated E-step results each EM iteration and computes the value of the M-step objective for each candidate size.
// Sizes are optimized one PARAMETER at a time; between consecutive optimizations the per-parameter state must be updated.
public class MstepLinearPac implements MstepLinear {

    // when true, value() prints every core and the corresponding E-step summaries
    private boolean debug = false;

    // never updated after construction
    private final ParamSet pSet;
    private final double[] times;      // start point of each discretization interval
    private final int d;               // number of discretization intervals (times.length)
    private final int nTotal;
    private final int[] paramPattern;  // number of consecutive discretization intervals governed by each param
    private final double[][] descendingAscendingTable;

    // replaced before each EM iteration (see updateEachIter)
    private EstepLinearPac eStep;

    // per-param state, advanced by updateParamIdx
    // NOTE: each array has one entry per trunk size n = 1 .. nTotal-1, stored at index n-1
    private int paramIdx;
    private final double[] multFactorArray;
    private final double[] rescaledTimeArray;
    private final double[] nbarArray; // nbars as of the previous param (i.e. for first param, nbar = nTrunk)

    public MstepLinearPac(ParamSet pSet, double[] times, int nTotal, int[] paramPattern, double[][] descendingAscendingTable) {
        this.pSet = pSet;
        this.times = times;
        this.d = times.length;
        this.nTotal = nTotal;
        this.paramPattern = paramPattern;
        this.descendingAscendingTable = descendingAscendingTable;

        // one slot per trunk size
        this.multFactorArray = new double[nTotal - 1];
        this.rescaledTimeArray = new double[nTotal - 1];
        this.nbarArray = new double[nTotal - 1];
    }

    /**
     * Computes the log probability of emitting allele {@code a}, marginalized over the trunk allele:
     * log-sum over trunk alleles of {@code core.getLogEmission(trunkAllele, a) + logStationaryProb[trunkAllele]}.
     */
    private double computeLogEmission(int a, CoreLinearOneSize core) {
        LogSum mutSum = new LogSum(this.pSet.numAlleles());
        for (int trunkAllele = 0; trunkAllele < this.pSet.numAlleles(); trunkAllele++) {
            mutSum.addLogSummand(core.getLogEmission(trunkAllele, a) + this.pSet.getLogStationaryProb()[trunkAllele]);
        }
        return mutSum.retrieveLogSum();
    }

    // index of the first discretization interval governed by the given param (sum of the preceding pattern entries)
    private int firstIntervalOf(int param) {
        int start = 0;
        for (int i = 0; i < param; i++) {
            start += this.paramPattern[i];
        }
        return start;
    }

    // end point of discretization interval j; the last interval is unbounded
    private double intervalEnd(int j) {
        return j < this.d - 1 ? this.times[j + 1] : Double.POSITIVE_INFINITY;
    }

    // exp(-intervalLength * nbar / size): the factor by which multFactor shrinks across an interval
    private static double decayFactor(double intervalLength, double nbar, double size) {
        return Math.exp(-intervalLength * nbar / size);
    }

    // advances the state of trunk size n across one interval; decay must have been computed with the PRE-update nbar
    private void advanceTrunkState(int n, double intervalLength, double size, double decay,
            double[] multFactor, double[] rescaledTime, double[] nbar) {
        multFactor[n - 1] *= decay;
        rescaledTime[n - 1] += intervalLength / size;
        nbar[n - 1] = Discretization.nbarTime(n, rescaledTime[n - 1], this.descendingAscendingTable[n - 1]);
    }

    /**
     * Evaluates the M-step objective, which should be MAXIMIZED, for a candidate size applied to every
     * discretization interval spanned by the current param. This method does not change any fields.
     *
     * @param size candidate size for all intervals spanned by the current param
     * @return the summed contribution of those intervals, each permutation weighted by
     *         {@code exp(logLikelihood[p] - eStepLogLikelihood)}
     */
    public double value(double size) {

        // we are maximizing the total log prob
        double mStepLogLikelihood = 0.0;
        int numPerms = this.eStep.getAllExpectedRecoDiffCoalSame().length;
        int numAlleles = this.pSet.numAlleles();

        // the current param governs discretization intervals [startIdx, endIdx)
        int startIdx = firstIntervalOf(this.paramIdx);
        int endIdx = startIdx + this.paramPattern[this.paramIdx];

        // work on copies: we are only testing this size, so the fields must not change
        double[] localMultFactorArray = this.multFactorArray.clone();
        double[] localRescaledTimeArray = this.rescaledTimeArray.clone();
        double[] localNbarArray = this.nbarArray.clone();

        for (int j = startIdx; j < endIdx; j++) {

            // start and end points of the interval (the weight differs for each n)
            double startPoint = this.times[j];
            double endPoint = intervalEnd(j);
            double intervalLength = endPoint - startPoint;

            // one core per trunk size, plus its marginal log emissions; the emissions do not depend on the
            // permutation, so they are computed once here instead of once per permutation
            CoreLinearOneSize[] allCores = new CoreLinearOneSize[this.nTotal - 1];
            double[][] logEmissions = new double[this.nTotal - 1][numAlleles];
            for (int n = 1; n < this.nTotal; n++) {

                double mainWeight = decayFactor(intervalLength, localNbarArray[n - 1], size);
                double newWeight = localMultFactorArray[n - 1] * (1 - mainWeight);
                DiscretizationInterval tPoint = new DiscretizationInterval(newWeight, startPoint, endPoint);

                allCores[n - 1] = new CoreLinearOneSize(this.pSet, tPoint, size, localNbarArray[n - 1]);
                if (this.debug) {
                    System.out.println("\nj=" + j + ", n=" + n);
                    System.out.println(allCores[n - 1].toString());
                    for (int p = 0; p < numPerms; p++) { System.out.println(this.eStep.print(p, n, j)); }
                }

                for (int a = 0; a < numAlleles; a++) {
                    logEmissions[n - 1][a] = computeLogEmission(a, allCores[n - 1]);
                }

                // move the local state to the start of the next interval; mainWeight was computed with the
                // pre-update nbar, which is exactly the factor multFactor needs
                advanceTrunkState(n, intervalLength, size, mainWeight,
                        localMultFactorArray, localRescaledTimeArray, localNbarArray);
            }

            for (int p = 0; p < numPerms; p++) {

                // permQprob is the negative log likelihood for the permutation
                double permQprob = 0;

                for (int n = 1; n < this.nTotal; n++) {
                    CoreLinearOneSize core = allCores[n - 1];

                    // initial marginal probabilities
                    permQprob -= this.eStep.getAllExpectedInitialCounts()[p][n - 1][j] * core.getLogInitialWeight();

                    if (this.debug) {
                        System.out.println("---interval " + p + " " + n + " " + j + "---");
                        System.out.println("initial: " + (this.eStep.getAllExpectedInitialCounts()[p][n - 1][j] * core.getLogInitialWeight()));
                    }

                    // transition terms; not everything is added on for the last time index
                    permQprob -= this.eStep.getAllExpectedRecoSameCoalSame()[p][n - 1][j] * core.getLogRecoSameCoalSame();
                    permQprob -= this.eStep.getAllExpectedNoReco()[p][n - 1][j] * core.getLogNoReco();

                    // for the last interval these expectations should be 0, but numerically (e.g. 0*-Infinity) they
                    // are not, so they are skipped there
                    if (j < this.d - 1) {
                        permQprob -= this.eStep.getAllExpectedRecoDiffCoalSame()[p][n - 1][j] * core.getLogRecoDiffCoalSame();
                        permQprob -= this.eStep.getAllExpectedRecoDiffCoalLater()[p][n - 1][j] * core.getLogRecoDiffCoalLater();
                        permQprob -= this.eStep.getAllExpectedRecoSameCoalLater()[p][n - 1][j] * core.getLogRecoSameCoalLater();
                        permQprob -= this.eStep.getAllExpectedCoalNow()[p][n - 1][j] * core.getLogCoalNow();
                        permQprob -= this.eStep.getAllExpectedCoalLater()[p][n - 1][j] * core.getLogCoalLater();
                    }

                    // emissions
                    for (int a = 0; a < numAlleles; a++) {
                        permQprob -= this.eStep.getAllExpectedEmissions()[p][n - 1][j][a] * logEmissions[n - 1][a];
                    }
                }

                // weight by the probability of this permutation (divide by the total probability)
                permQprob *= Math.exp(this.eStep.getAllLogLikelihoods()[p] - this.eStep.getEstepLogLikelihood());

                // subtract, since permQprob is a negative log likelihood
                mStepLogLikelihood -= permQprob;
            }
        }

        return mStepLogLikelihood;
    }

    /** Enables or disables the diagnostic printing done by {@link #value(double)}. */
    public void setDebug(boolean debug) {
        this.debug = debug;
    }

    /**
     * Installs the E-step for the coming EM iteration and resets the per-param state to the first param.
     * The argument is cast to {@code EstepLinearPac}.
     */
    public void updateEachIter(Estep eStep) {
        this.eStep = (EstepLinearPac) eStep;
        this.paramIdx = 0;

        for (int n = 1; n < this.nTotal; n++) {
            this.multFactorArray[n - 1] = 1;   // starts off at 1 for all trunk sizes
            this.rescaledTimeArray[n - 1] = 0; // starts off at 0 for all trunk sizes
            this.nbarArray[n - 1] = n;         // starts off at trunk size
        }
    }

    /**
     * Moves on to parameter {@code paramIdx} (an index into paramPattern). Call once for each param after the
     * first. Advances the per-param state across every interval of the PREVIOUS param, using the size chosen for it.
     *
     * @param paramIdx index of the param about to be optimized; must be in [1, paramPattern.length]
     * @param prevSize the size chosen for param {@code paramIdx - 1}
     * @throws IllegalArgumentException if {@code paramIdx} is out of range (checked before any state changes)
     */
    public void updateParamIdx(int paramIdx, double prevSize) {
        if (paramIdx <= 0 || paramIdx > this.paramPattern.length) {
            throw new IllegalArgumentException(
                    "paramIdx must be in [1, " + this.paramPattern.length + "], got " + paramIdx);
        }
        this.paramIdx = paramIdx;

        // the PREVIOUS param governs discretization intervals [prevStartIdx, prevEndIdx)
        int prevParam = paramIdx - 1;
        int prevStartIdx = firstIntervalOf(prevParam);
        int prevEndIdx = prevStartIdx + this.paramPattern[prevParam];

        for (int j = prevStartIdx; j < prevEndIdx; j++) {
            double intervalLength = intervalEnd(j) - this.times[j];
            for (int n = 1; n < this.nTotal; n++) {
                // decay uses the nbar from BEFORE this interval
                double decay = decayFactor(intervalLength, this.nbarArray[n - 1], prevSize);
                advanceTrunkState(n, intervalLength, prevSize, decay,
                        this.multFactorArray, this.rescaledTimeArray, this.nbarArray);
            }
        }
    }
}
