/*
 * License: FreeBSD (Berkeley Software Distribution)
 * Copyright (c) 2013, Sara Sheehan, Kelley Harris, Yun Song
 */

package edu.berkeley.smcsd;

import edu.berkeley.utility.ParamSet;
import edu.berkeley.utility.Utility;

/**
 * M-step objective for the quadrature-based model, "Pac" variant.
 *
 * <p>Given population sizes, {@link #value(double[])} rebuilds one {@link CoreQuad} per trunk
 * size {@code n = 1 .. nTotal-1}. It then scores those cores against the expected counts of the
 * most recent E-step, summing over every permutation the E-step recorded. Each permutation is
 * weighted by {@code exp(permLogLik - totalLogLik)}. The result is a log-likelihood-like quantity
 * that callers are expected to MAXIMIZE.
 *
 * <p>NOTE(review): "Pac" presumably stands for product of approximate conditionals, with the
 * permutations being haplotype orderings. Confirm against EstepQuadPac.
 */
public class MstepQuadPac implements MstepQuad {

	/** When true, {@link #value(double[])} prints per-interval diagnostics to stdout. */
	private boolean debug = false;

	// never update
	private final ParamSet pSet;
	private final double[] times;
	/** Number of discretisation intervals (equal to {@code times.length}). */
	private final int d;
	private final int nTotal;
	private final int[] paramPattern;
	private final double minSize;
	private final double maxSize;
	/** Row {@code n-1} is passed to the {@link CoreQuad} built for trunk size {@code n}. */
	private final double[][] descendingAscendingTable;

	// update each iteration
	private EstepQuadPac eStep;

	/**
	 * Creates the M-step objective.
	 *
	 * @param pSet                      model parameter set, passed to every {@link CoreQuad}
	 * @param times                     interval boundaries; its length defines the number of intervals
	 * @param nTotal                    total haplotype count; trunk sizes {@code 1 .. nTotal-1} are scored
	 * @param paramPattern              pattern used by {@code Utility.expandSizes} to expand the short
	 *                                  size vector into one size per interval
	 * @param minSize                   lower bound for every optimised size
	 * @param maxSize                   upper bound for every optimised size
	 * @param descendingAscendingTable  per-trunk-size table; row {@code n-1} is used for trunk size {@code n}
	 */
	public MstepQuadPac(ParamSet pSet, double[] times, int nTotal, int[] paramPattern, double minSize, double maxSize, double[][] descendingAscendingTable) {
		this.pSet = pSet;
		this.times = times;
		this.d = times.length;
		this.nTotal = nTotal;
		this.paramPattern = paramPattern;
		this.minSize = minSize;
		this.maxSize = maxSize;
		this.descendingAscendingTable = descendingAscendingTable;
	}

	/**
	 * Evaluates the M-step objective at the given (compressed) size vector. Larger is better.
	 *
	 * @param sizesShort one size per free parameter, before expansion by {@code paramPattern}
	 * @return the permutation-weighted expected complete-data log-likelihood, or
	 *         {@link Double#NEGATIVE_INFINITY} if any size lies outside [minSize, maxSize]
	 * @throws IllegalStateException if {@link #updateEachIter(Estep)} has not supplied an E-step yet
	 */
	public double value(double[] sizesShort) {

		// This function is MAXIMIZED, so an infeasible point must get the worst possible
		// score (-inf). Returning +inf would make it look optimal to the optimizer.
		if (Utility.outOfBounds(sizesShort, this.minSize, this.maxSize)) {
			return Double.NEGATIVE_INFINITY;
		}

		final EstepQuadPac e = this.eStep;
		if (e == null) {
			throw new IllegalStateException("updateEachIter must be called before value()");
		}

		final CoreQuad[] allCores = buildCores(Utility.expandSizes(sizesShort, this.paramPattern));

		final int numPerms = e.getAllExpectedTransitions().length;
		double mStepLogLikelihood = 0.0;

		for (int p = 0; p < numPerms; p++) {
			double permQprob = permutationNegLogLikelihood(e, allCores, p);

			// Weight by the posterior probability of this permutation (its likelihood divided by
			// the total E-step likelihood, computed in log space).
			permQprob *= Math.exp(e.getAllLogLikelihoods()[p] - e.getEstepLogLikelihood());

			// permQprob is a negative log-likelihood, so subtracting it adds log-likelihood.
			mStepLogLikelihood -= permQprob;
		}

		return mStepLogLikelihood;
	}

	/** Builds one CoreQuad per trunk size n = 1 .. nTotal-1; index n-1 holds trunk size n. */
	private CoreQuad[] buildCores(double[] sizes) {
		CoreQuad[] allCores = new CoreQuad[this.nTotal - 1];
		for (int n = 1; n < this.nTotal; n++) {
			allCores[n - 1] = new CoreQuad(this.pSet, this.times, sizes, n, this.descendingAscendingTable[n - 1]);
		}
		return allCores;
	}

	/**
	 * Returns the unweighted negative expected log-likelihood for permutation {@code p}.
	 * It sums the initial, transition and emission terms over all trunk sizes and intervals.
	 */
	private double permutationNegLogLikelihood(EstepQuadPac e, CoreQuad[] allCores, int p) {
		double permQprob = 0;

		for (int n = 1; n < this.nTotal; n++) {
			final CoreQuad core = allCores[n - 1];

			for (int i = 0; i < this.d; i++) {

				// initial-interval (marginal) term
				permQprob -= e.getAllExpectedInitialCounts()[p][n - 1][i] * core.getLogInitialWeight(i);

				if (this.debug) {
					System.out.println("---interval " + p + " " + n + " " + i + "---");
					System.out.println("initial: " + (e.getAllExpectedInitialCounts()[p][n - 1][i] * core.getLogInitialWeight(i)));
				}

				// transition terms i -> j
				for (int j = 0; j < this.d; j++) {
					permQprob -= e.getAllExpectedTransitions()[p][n - 1][i][j] * core.getLogTransition(i, j);
				}

				// emission terms for each allele a in interval i
				for (int a = 0; a < this.pSet.numAlleles(); a++) {
					permQprob -= e.getAllExpectedEmissions()[p][n - 1][i][a] * core.getLogMarginalEmission(i, a);
				}
			}
		}

		return permQprob;
	}

	/**
	 * Enables or disables diagnostic printing in {@link #value(double[])}.
	 *
	 * @param debug true to print per-interval diagnostics to stdout
	 */
	public void setDebug(boolean debug) {
		this.debug = debug;
	}

	/**
	 * Supplies the E-step results used by subsequent {@link #value(double[])} calls.
	 * It must be called before each M-step iteration.
	 *
	 * @param eStep the current E-step; must be an {@link EstepQuadPac}
	 * @throws ClassCastException if {@code eStep} is not an {@link EstepQuadPac}
	 */
	public void updateEachIter(Estep eStep) {
		this.eStep = (EstepQuadPac) eStep;
	}
}
