/*
 * License: FreeBSD (Berkeley Software Distribution)
 * Copyright (c) 2013, Sara Sheehan, Kelley Harris, Yun Song
 */

package edu.berkeley.smcsd;

import edu.berkeley.utility.Discretization;
import edu.berkeley.utility.LogSum;
import edu.berkeley.utility.ParamSet;
import edu.berkeley.utility.Discretization.DiscretizationInterval;

/**
 * M-step objective for the linear model.
 *
 * <p>This object holds the E-step expected counts for the current EM iteration (set via
 * {@link #updateEachIter(Estep)}). It evaluates the M-step log-likelihood for a candidate
 * population size through {@link #value(double)}.
 *
 * <p>Sizes are optimized one PARAMETER at a time. Parameter {@code p} governs
 * {@code paramPattern[p]} consecutive discretization intervals. Between optimizations the running
 * state (multFactor, rescaled time, nbar) must be advanced with
 * {@link #updateParamIdx(int, double)}. The next parameter is then evaluated conditional on the
 * sizes already chosen.
 *
 * <p>Call sequence per EM iteration:
 * <ol>
 *   <li>{@code updateEachIter(eStep)}, then {@code value(...)} as often as needed for parameter 0;</li>
 *   <li>for p = 1, 2, ...: {@code updateParamIdx(p, chosenSizeOf(p - 1))}, then {@code value(...)}.</li>
 * </ol>
 * {@code updateParamIdx} only folds in the intervals of parameter {@code paramIdx - 1}, so it must
 * be called with consecutive indices.
 */
public class MstepLinearLol implements MstepLinear {

	// when true, value(...) prints per-interval diagnostics to stdout
	private boolean debug = false;

	// fixed for the lifetime of this object
	private final ParamSet pSet;
	private final double[] times; // left endpoints of the discretization intervals
	private final int d; // number of discretization intervals (times.length)
	private final int nTrunk; // starting value of nbar at the beginning of each EM iteration
	private final int[] paramPattern; // paramPattern[p] = number of consecutive intervals governed by parameter p
	private final double[] descendingAscendingTable; // passed through to Discretization.nbarTime

	// replaced each EM iteration
	private EstepLinearLol eStep;

	// running state, advanced once per parameter index
	private int paramIdx;
	private double multFactor;
	private double rescaledTime;
	private double nbar; // nbar as of the previous param (i.e. for first param, nbar = nTrunk)

	/**
	 * @param pSet model parameters (allele count, stationary distribution, ...)
	 * @param times left endpoints of the discretization intervals; the last interval extends to infinity
	 * @param nTrunk initial value of nbar for each EM iteration
	 * @param paramPattern number of consecutive intervals governed by each size parameter
	 * @param descendingAscendingTable lookup table forwarded to {@code Discretization.nbarTime}
	 */
	public MstepLinearLol(ParamSet pSet, double[] times, int nTrunk, int[] paramPattern, double[] descendingAscendingTable) {
		this.pSet = pSet;
		this.times = times;
		this.d = times.length;
		this.nTrunk = nTrunk;
		this.paramPattern = paramPattern;
		this.descendingAscendingTable = descendingAscendingTable;
	}

	// index (into times) of the first discretization interval governed by parameter p
	private int startIndexOf(int p) {
		int start = 0;
		for (int i = 0; i < p; i++) {
			start += this.paramPattern[i];
		}
		return start;
	}

	// right endpoint of discretization interval j; the last interval is unbounded
	private double intervalEnd(int j) {
		return j < this.d - 1 ? this.times[j + 1] : Double.POSITIVE_INFINITY;
	}

	// log probability of emitting allele a, marginalized over the trunk allele:
	// log sum_t exp(core.getLogEmission(t, a) + logStationary[t])
	private double computeLogEmission(int a, CoreLinearOneSize core) {
		int numAlleles = this.pSet.numAlleles();
		LogSum mutSum = new LogSum(numAlleles);
		for (int trunkAllele = 0; trunkAllele < numAlleles; trunkAllele++) {
			mutSum.addLogSummand(core.getLogEmission(trunkAllele, a) + this.pSet.getLogStationaryProb()[trunkAllele]);
		}
		return mutSum.retrieveLogSum();
	}

	/**
	 * Evaluates the M-step objective (to be MAXIMIZED) for the current parameter.
	 *
	 * <p>Every discretization interval spanned by the current parameter is assumed to have
	 * population size {@code size}. No fields are modified: the running multFactor, rescaled time
	 * and nbar are advanced on local copies only.
	 *
	 * @param size candidate size for all intervals governed by the current parameter
	 * @return the summed expected-count-weighted log terms for those intervals
	 * @throws IllegalStateException if {@link #updateEachIter(Estep)} has not been called yet
	 */
	public double value(double size) {
		if (this.eStep == null) {
			throw new IllegalStateException("updateEachIter must be called before value");
		}

		int startIdx = startIndexOf(this.paramIdx);
		int endIdx = startIdx + this.paramPattern[this.paramIdx];

		// local copies: we are only trying out this size, so the fields must stay untouched
		double localMultFactor = this.multFactor;
		double localRescaledTime = this.rescaledTime;
		double localNbar = this.nbar;

		double mStepLogLikelihood = 0.0;
		for (int j = startIdx; j < endIdx; j++) {
			double startPoint = this.times[j];
			double endPoint = intervalEnd(j);
			double length = endPoint - startPoint;

			// factor by which multFactor shrinks across this interval (nbar held fixed within it);
			// computed once and reused for both the interval weight and the multFactor update
			double decay = Math.exp(-length * localNbar / size);
			DiscretizationInterval tPoint = new DiscretizationInterval(localMultFactor * (1 - decay), startPoint, endPoint);

			CoreLinearOneSize core = new CoreLinearOneSize(this.pSet, tPoint, size, localNbar);
			if (this.debug) {
				System.out.println(this.eStep.print(j));
				System.out.println(core.toString());
			}

			// initial marginal probabilities
			double initialTerm = this.eStep.getExpectedInitialCounts()[j] * Math.log(tPoint.weight);
			mStepLogLikelihood += initialTerm;
			if (this.debug) {
				System.out.println("---interval " + j + "---");
				System.out.println("initial: " + initialTerm);
			}

			// transition terms; not everything is added on for the last time index
			mStepLogLikelihood += this.eStep.getExpectedRecoSameCoalSame()[j] * core.getLogRecoSameCoalSame();
			mStepLogLikelihood += this.eStep.getExpectedNoReco()[j] * core.getLogNoReco();

			// for the last interval (j = d-1) these expected counts should be 0, but numerically
			// (e.g. 0 * -Infinity) they are not, so they are skipped there
			if (j < this.d - 1) {
				mStepLogLikelihood += this.eStep.getExpectedRecoDiffCoalSame()[j] * core.getLogRecoDiffCoalSame();
				mStepLogLikelihood += this.eStep.getExpectedRecoDiffCoalLater()[j] * core.getLogRecoDiffCoalLater();
				mStepLogLikelihood += this.eStep.getExpectedRecoSameCoalLater()[j] * core.getLogRecoSameCoalLater();
				mStepLogLikelihood += this.eStep.getExpectedCoalNow()[j] * core.getLogCoalNow();
				mStepLogLikelihood += this.eStep.getExpectedCoalLater()[j] * core.getLogCoalLater();
			}

			// emission terms
			for (int a = 0; a < this.pSet.numAlleles(); a++) {
				mStepLogLikelihood += this.eStep.getExpectedEmissions()[j][a] * computeLogEmission(a, core);
			}

			// advance the local state; multFactor first because it needs the previous nbar
			localMultFactor *= decay;
			localRescaledTime += length / size;
			localNbar = Discretization.nbarTime(this.nTrunk, localRescaledTime, this.descendingAscendingTable);
		}

		return mStepLogLikelihood;
	}

	/** Enables or disables per-interval diagnostic printing in {@link #value(double)}. */
	public void setDebug(boolean debug) {
		this.debug = debug;
	}

	/**
	 * Installs the E-step results for a new EM iteration and resets the running state to parameter 0.
	 * Call this BEFORE each EM iteration.
	 *
	 * @param eStep E-step results; must be an {@code EstepLinearLol} (it is cast unconditionally)
	 */
	public void updateEachIter(Estep eStep) {
		this.eStep = (EstepLinearLol) eStep;
		this.paramIdx = 0;
		this.multFactor = 1;
		this.rescaledTime = 0;
		this.nbar = nTrunk;
	}

	/**
	 * Moves to parameter {@code paramIdx} and folds the previously chosen size into the running state.
	 *
	 * <p>Call this once for each parameter after the first one, with consecutive indices. Only the
	 * intervals of parameter {@code paramIdx - 1} are folded in.
	 *
	 * @param paramIdx index into paramPattern of the parameter about to be optimized (must be &gt; 0)
	 * @param prevSize size chosen for parameter {@code paramIdx - 1}
	 * @throws IllegalArgumentException if {@code paramIdx <= 0}
	 */
	public void updateParamIdx(int paramIdx, double prevSize) {
		// explicit check: a bare assert is disabled without -ea, and the original state would be
		// mutated before the out-of-bounds access failed
		if (paramIdx <= 0) {
			throw new IllegalArgumentException("paramIdx must be > 0, got " + paramIdx);
		}
		this.paramIdx = paramIdx;

		// range of discretization indices governed by the PREVIOUS parameter
		int prevStartIdx = startIndexOf(this.paramIdx - 1);
		int prevEndIdx = prevStartIdx + this.paramPattern[this.paramIdx - 1];

		// advance the running state using the size chosen for the PREVIOUS parameter
		for (int j = prevStartIdx; j < prevEndIdx; j++) {
			double startPoint = this.times[j];
			double length = intervalEnd(j) - startPoint;

			// multFactor first, because it needs the previous nbar
			this.multFactor *= Math.exp(-length * this.nbar / prevSize);
			this.rescaledTime += length / prevSize;
			this.nbar = Discretization.nbarTime(this.nTrunk, this.rescaledTime, this.descendingAscendingTable);
		}
	}
}
