/*
 * License: FreeBSD (Berkeley Software Distribution)
 * Copyright (c) 2013, Sara Sheehan, Kelley Harris, Yun Song
 */

package edu.berkeley.smcsd;

import edu.berkeley.utility.ParamSet;
import edu.berkeley.utility.Utility;

/**
 * M-step objective paired with {@link EstepQuadLol}.
 *
 * <p>For a candidate vector of (short-form) population sizes, {@link #value(double[])} builds a
 * {@link CoreQuad} from those sizes and returns the sum, over all intervals, of the E-step's
 * expected initial counts, expected transitions and expected emissions, each multiplied by the
 * corresponding log probability under that {@code CoreQuad}. The optimizer is expected to
 * MAXIMIZE this value.
 *
 * <p>Call {@link #updateEachIter(Estep)} before each M-step so that the objective uses the
 * current E-step expectations.
 */
public class MstepQuadLol implements MstepQuad {

    // when true, value() prints per-interval diagnostic output to stdout
    private boolean debug = false;

    // fixed at construction (arrays are stored by reference, not copied)
    private final ParamSet pSet;
    private final double[] times;
    private final int d;               // number of intervals, equal to times.length
    private final int nTrunk;
    private final int[] paramPattern;  // how sizesShort is expanded into full sizes
    private final double minSize;
    private final double maxSize;
    private final double[] descendingAscendingTable;

    // replaced before each EM iteration via updateEachIter()
    private EstepQuadLol eStep;

    /**
     * Creates the M-step objective.
     *
     * @param pSet                     model parameters; passed to {@code CoreQuad}, and its
     *                                 {@code numAlleles()} bounds the emission sum
     * @param times                    interval boundaries; its length sets the number of
     *                                 intervals iterated in {@link #value(double[])}
     * @param nTrunk                   forwarded unchanged to {@code CoreQuad}
     * @param paramPattern             pattern used by {@code Utility.expandSizes} to expand the
     *                                 optimizer's short size vector
     * @param minSize                  lower bound handed to {@code Utility.outOfBounds}
     * @param maxSize                  upper bound handed to {@code Utility.outOfBounds}
     * @param descendingAscendingTable forwarded unchanged to {@code CoreQuad}
     */
    public MstepQuadLol(ParamSet pSet, double[] times, int nTrunk, int[] paramPattern, double minSize, double maxSize, double[] descendingAscendingTable) {
        this.pSet = pSet;
        this.times = times;
        this.d = times.length;
        this.nTrunk = nTrunk;
        this.paramPattern = paramPattern;
        this.minSize = minSize;
        this.maxSize = maxSize;
        this.descendingAscendingTable = descendingAscendingTable;
    }

    /**
     * Evaluates the M-step objective at the given sizes. The caller should MAXIMIZE it.
     *
     * @param sizesShort the free size parameters, expanded with
     *                   {@code Utility.expandSizes(sizesShort, paramPattern)}
     * @return the sum of expected counts times the log probabilities of the {@code CoreQuad}
     *         built from these sizes, or {@link Double#NEGATIVE_INFINITY} when
     *         {@code Utility.outOfBounds(sizesShort, minSize, maxSize)} reports the point as
     *         infeasible
     * @throws IllegalStateException if {@link #updateEachIter(Estep)} has not been called yet
     */
    public double value(double[] sizesShort) {

        // Infeasible point: return the WORST value for a maximizer. The original code returned
        // +infinity here, which is the best score for a maximizer.
        if (Utility.outOfBounds(sizesShort, this.minSize, this.maxSize)) {
            return Double.NEGATIVE_INFINITY;
        }

        if (this.eStep == null) {
            throw new IllegalStateException("updateEachIter must be called before value()");
        }

        // expand sizesShort to the full size vector
        double[] sizes = Utility.expandSizes(sizesShort, this.paramPattern);

        // build the CSD implied by the candidate sizes
        CoreQuad core = new CoreQuad(this.pSet, this.times, sizes, this.nTrunk, this.descendingAscendingTable);

        // the E-step expectations do not change during one evaluation, so fetch them once
        final double[] initialCounts = this.eStep.getExpectedInitialCounts();
        final double[][] transitions = this.eStep.getExpectedTransitions();
        final double[][] emissions = this.eStep.getExpectedEmissions();
        final int numAlleles = this.pSet.numAlleles();

        double mStepLogLikelihood = 0.0;

        for (int i = 0; i < this.d; i++) {

            // expected initial counts weighted by the log initial weight
            double initialTerm = initialCounts[i] * core.getLogInitialWeight(i);
            mStepLogLikelihood += initialTerm;

            if (this.debug) {
                System.out.println("---interval " + i + "---");
                System.out.println("initial: " + initialTerm);
            }

            // add the expected transitions weighted by their log probabilities
            for (int j = 0; j < this.d; j++) {
                mStepLogLikelihood += transitions[i][j] * core.getLogTransition(i, j);
            }

            // add the expected emissions weighted by their log marginal emission probabilities
            for (int a = 0; a < numAlleles; a++) {
                mStepLogLikelihood += emissions[i][a] * core.getLogMarginalEmission(i, a);
            }
        }
        return mStepLogLikelihood;
    }

    /**
     * Enables or disables per-interval diagnostic printing in {@link #value(double[])}.
     *
     * @param debug {@code true} to print diagnostics to stdout
     */
    public void setDebug(boolean debug) {
        this.debug = debug;
    }

    /**
     * Installs the E-step whose expected counts the next evaluations of
     * {@link #value(double[])} will use. Call this before each EM iteration.
     *
     * @param eStep the current E-step; must be an {@link EstepQuadLol}
     * @throws ClassCastException if {@code eStep} is not an {@code EstepQuadLol}
     */
    public void updateEachIter(Estep eStep) {
        this.eStep = (EstepQuadLol) eStep;
    }
}
