/*
 * License: FreeBSD (Berkeley Software Distribution)
 * Copyright (c) 2013, Sara Sheehan, Kelley Harris, Yun Song
 * 
 * based on:
 * Copyright (c) 2011, Joshua Paul, Matthias Steinrücken, Yun Song
 */

package edu.berkeley.smcsd;

import java.util.Arrays;

import edu.berkeley.utility.Discretization;
import edu.berkeley.utility.LogSum;
import edu.berkeley.utility.ParamSet;
import edu.berkeley.utility.Utility;
import edu.berkeley.utility.Discretization.DiscretizationInterval;
import Jama.Matrix;

// original core with variable population size
public class CoreQuad {
	
	private final static int NUM_SUM = 500;
	
	private final ParamSet pSet;
	private final double[] sizes;
	private final int nTrunk;
	private final double[] nbar;
	private final DiscretizationInterval[] tPoints;
	
	// no recombination matrix: indexed by t-index (src)
	private final double[] logNoRecMatrices;
	
	// recombination matrix: indexed by time-index (src), time-index (dst)
	private final double[][] logRecMatrices;
	
	// index by locus, t-index, allele (src), allele (dst)
	private final double[][][] logMutMatrices;
	
	public CoreQuad(ParamSet pSet, double[] times, double[] sizes, int nTrunk, double[] descendingAscendingTable) {

		this.pSet = pSet;
		this.sizes = sizes;
		this.nTrunk = nTrunk;
		this.nbar = Discretization.computeNbarArray(times, sizes, nTrunk, descendingAscendingTable);
		this.tPoints = Discretization.makeDiscretizationPoints(times, sizes, nTrunk, this.nbar);
		
		// compute recombination matrices 
		this.logNoRecMatrices = computeLogNoRecMatrix();
		this.logRecMatrices = computeLogRecMatrix();
		
		// compute mutation matrices 
		this.logMutMatrices = computeLogQMatrix();
	}
	
	//----------------
	// PUBLIC GETTERS
	//----------------
	
	public int numIntervals() { return this.tPoints.length; }
	public double[] getPopSizes() { return this.sizes; }
	public int nTrunk() { return this.nTrunk; }
	public int numAlleles() { return this.pSet.numAlleles(); }
	
	// returns the start points of all the quadrature points
	public double[] getPoints() {
		double[] timeIdxMap = new double[numIntervals()];
		for (int i=0; i < numIntervals(); i++) {
			timeIdxMap[i] = this.tPoints[i].startPoint;
		}
		return timeIdxMap;
	}
	
	// quadrature weight getters
	public double getLogInitialWeight(int tIdx) {
		return Math.log(this.tPoints[tIdx].weight);
	}
	
	// no rec getter
	public double getLogNoRecombinationTransition(int tIdx) {
		return this.logNoRecMatrices[tIdx];
	}
	
	// rec getter
	public double getLogRecombinationTransition(int srcTIdx, int dstTIdx) {
		return this.logRecMatrices[srcTIdx][dstTIdx];
	}
	
	// emission getter, modifying to handle missing data (emit 1)
	public double getLogEmission(int srcAllele, int dstAllele, int tIdx) {
		// if dstAllele is unknown (N), probability is 1
		int unknown = this.pSet.numAlleles(); // last allele always represents unknown base
		if (dstAllele == unknown) {
			return 0; // log(1)
			
		// otherwise if srcAllele is unknown, sum over all possible values for srcAllele
		} else if (srcAllele == unknown) {
			LogSum mutSum = new LogSum(pSet.numAlleles());
	    	for (int trunkAllele=0; trunkAllele < pSet.numAlleles(); trunkAllele++) {
	    		mutSum.addLogSummand(this.logMutMatrices[tIdx][trunkAllele][dstAllele] + this.pSet.getLogStationaryProb()[trunkAllele]);
	    	}
	    	return mutSum.retrieveLogSum();
		}
		
		// if we know both srcAllele and dstAllele, just return from the matrix
		return this.logMutMatrices[tIdx][srcAllele][dstAllele];
	}
	
	// get transition probability (including reco and no reco)
    public double getLogTransition(int srcTIdx, int dstTIdx) {
    	double logRecTrans = getLogRecombinationTransition(srcTIdx, dstTIdx);
    	if (srcTIdx == dstTIdx) {
    		LogSum newSum = new LogSum(2);
    		newSum.addLogSummand(logRecTrans);
    		newSum.addLogSummand(getLogNoRecombinationTransition(srcTIdx));
    		return newSum.retrieveLogSum();
    	} else {
    		return logRecTrans;
    	}
    }
    
    // compute the log probability of emitting an a from time i
    public double getLogMarginalEmission(int tIdx, int allele) {
    	LogSum mutSum = new LogSum(this.pSet.numAlleles());
    	for (int trunkAllele=0; trunkAllele < this.pSet.numAlleles(); trunkAllele++) {
    		mutSum.addLogSummand(getLogEmission(trunkAllele, allele, tIdx) + this.pSet.getLogStationaryProb()[trunkAllele]);
    	}
    	return mutSum.retrieveLogSum();
    }
	
	//--------------------------------------------------
	// PRIVATE FUNCTIONS FOR COMPUTING TRANSITION PROBS
	//--------------------------------------------------
	
	private double[][] computeLogRecMatrix() {
		
		int d = this.tPoints.length;
		
		// precompute R values, not the last one, not needed
		double[] Rvalues = new double[d-1];
		for (int k=0; k < d-1; k++) {
			Rvalues[k] = computeR(k);
		}
		
		double[][] RsumValues = new double[d][d];
		for (int srcTIndex = 0; srcTIndex < d; srcTIndex++) {
			for (int dstTIndex = 0; dstTIndex < d; dstTIndex++) {
				for (int k=0; k <= Math.min(srcTIndex,dstTIndex)-1; k++) {
					RsumValues[srcTIndex][dstTIndex] += Rvalues[k];
				}
			}
		}
		
		double[][] logRecMatrices = new double[d][d];
		for (int srcTIndex = 0; srcTIndex < d; srcTIndex++) {
			for (int dstTIndex = 0; dstTIndex < d; dstTIndex++) {
				logRecMatrices[srcTIndex][dstTIndex] = Math.log(computeRecombinationTransition(srcTIndex, dstTIndex, RsumValues));
			}
		}
		
		return logRecMatrices;
	}
	
	// i is start time interval index and j is end time interval index
	private double computeRecombinationTransition(int i, int j, double[][] RsumValues) {
		double recRate = this.pSet.getRecombinationRate();
		
		double wi = this.tPoints[i].weight;
		double si = this.tPoints[i].startPoint;
		double ei = this.tPoints[i].endPoint;
		
		double wj = this.tPoints[j].weight;
		double sj = this.tPoints[j].startPoint;
		double ej = this.tPoints[j].endPoint;
		
		double bigZ = Double.NaN;
		
		// case 1: i < j
		if (i < j) {
				
			double a = Math.exp(-si*recRate);
			double b = sizes[i]*recRate/(nbar[i]-sizes[i]*recRate) * Math.exp(-(ei-si)*nbar[i]/sizes[i] - si*recRate);
			double c = - nbar[i]/(nbar[i]-sizes[i]*recRate)*Math.exp(-ei*recRate);
				
			bigZ = wj*(a+b+c)/wi;
				
		// case 2: i > j
		} else if (i > j) {
				
			double a = Math.exp(-sj*recRate);
			double b = sizes[j]*recRate/(nbar[j]-sizes[j]*recRate) * Math.exp(-(ej-sj)*nbar[j]/sizes[j] - sj*recRate);
			double c = - nbar[j]/(nbar[j]-sizes[j]*recRate)*Math.exp(-ej*recRate);
					
			bigZ = a+b+c;
			
		// case 3: i = j
		} else {
					
			double a = recRate*sizes[i]/(nbar[i]+recRate*sizes[i])*Math.exp(-si*recRate);
			double b = -2*Math.exp(-(ei-si)*nbar[i]/sizes[i] - si*recRate);
			double c = -recRate*sizes[i]/(nbar[i]-sizes[i]*recRate)*Math.exp(-si*recRate - 2*(ei-si)*nbar[i]/sizes[i]);
			double d = 2*Math.pow(nbar[i],2)/((nbar[i]-sizes[i]*recRate)*(nbar[i]+sizes[i]*recRate))*Math.exp(-ei*recRate - (ei-si)*nbar[i]/sizes[i]);
			double e = (1-Math.exp(-(ei-si)*nbar[i]/sizes[i]));
						
			bigZ = (a + b + c + d)/e;
		}
		
		// adding check for small probability
		double prob = bigZ + wj*RsumValues[i][j];
		if (prob < 0) {
			System.out.println("setting small negative transition probability: " + prob + " to zero (time " + i + " to " + j + ")");
			return 0;
		}
		return prob;
	}
	
	private double computeR(int k) {
		double recRate = this.pSet.getRecombinationRate();
		assert k != numIntervals()-1; // make sure we're not calling on last interval
		
		double prod = 1;
		for (int i=0; i < k; i++) {
			prod *= Math.exp((tPoints[i].endPoint - tPoints[i].startPoint)*nbar[i]/sizes[i]);
		}
		
		double sk = this.tPoints[k].startPoint;
		double ek = this.tPoints[k].endPoint;
		
		double Rvalue = prod * sizes[k]*recRate/(nbar[k]-sizes[k]*recRate) * (Math.exp((ek - sk)*nbar[k]/sizes[k] - ek*recRate) - Math.exp(-sk*recRate));
		return Rvalue;
	}
	
	private double[] computeLogNoRecMatrix() {
		double[] logNoRecMatrices = new double[this.tPoints.length];
		
		for (int tIndex = 0; tIndex < this.tPoints.length; tIndex++) {
			logNoRecMatrices[tIndex] = Math.log(computeNoRecombinationTransition(tIndex));
		}
		
		return logNoRecMatrices;
	}
	
	private double computeNoRecombinationTransition(int i) {
		double recRate = this.pSet.getRecombinationRate();
		double si = tPoints[i].startPoint;
		double ei = tPoints[i].endPoint;
		
		double coeff = this.nbar[i]/(this.nbar[i] + this.sizes[i]*recRate);
		double frac = (Math.exp(-si*recRate) - Math.exp(-ei*recRate - (ei-si)*this.nbar[i]/this.sizes[i])) / (1 - Math.exp(-(ei-si)*this.nbar[i]/this.sizes[i]));
		return coeff * frac;
	}
	
	//------------------------------------------------
	// PRIVATE FUNCTIONS FOR COMPUTING EMISSION PROBS
	//------------------------------------------------
	
	private double[][][] computeQMatrix() {
		
		double[][][] qMatrices = new double[tPoints.length][][];
		int numAlleles = pSet.numAlleles();
			
		for (int tIdx = 0; tIdx < tPoints.length; tIdx++) {
			qMatrices[tIdx] = aprxMatrixIntegralDiagonal(tIdx).getArray();

			double minVal = Double.POSITIVE_INFINITY;
			double maxVal = Double.NEGATIVE_INFINITY;

			for (int a1 = 0; a1 < numAlleles; a1++) {
				for (int a2 = 0; a2 < numAlleles; a2++) {
					maxVal = Math.max(qMatrices[tIdx][a1][a2], maxVal);
					minVal = Math.min(qMatrices[tIdx][a1][a2], minVal);
				}
			}

			if (maxVal > 1 || minVal < 0 || Double.isNaN(maxVal)) {
				System.out.println("Problem computing Q-matrices: replacing invalid entry with uniform transition matrix");
				System.out.println("bad sizes? " + Arrays.toString(sizes));
				System.out.println("bad time index? " + tIdx);
				System.out.println("matrix " + Utility.getMatrixString(qMatrices[tIdx]));
				for (int a1 = 0; a1 < numAlleles; a1++) {
					for (int a2 = 0; a2 < numAlleles; a2++) {
						qMatrices[tIdx][a1][a2] = 1/(double)numAlleles;
					}
				}
			}
		}
	
		return qMatrices;
	}
	
	// compute log of qMatrices
	private double[][][] computeLogQMatrix() {
		double[][][] qMatrices = computeQMatrix();
		double[][][] qLogMatrices = new double[tPoints.length][pSet.numAlleles()][pSet.numAlleles()];
		
		for (int t = 0; t < tPoints.length; t++) {

			for (int a1 = 0; a1 < pSet.numAlleles(); a1++) {
				for (int a2 = 0; a2 < pSet.numAlleles(); a2++) {
					qLogMatrices[t][a1][a2] = Math.log(qMatrices[t][a1][a2]);
				}
			}
		}
		return qLogMatrices;
	}
	
	// use diagonalization technique to compute matrix exponential
	private Matrix aprxMatrixIntegralDiagonal(int tIdx) {
		double cConst = this.pSet.getMutationRate() + this.nbar[tIdx]/this.sizes[tIdx];
		DiscretizationInterval tPoint = this.tPoints[tIdx];
		
		Matrix resultMatrix = Matrix.identity(this.pSet.numAlleles(), this.pSet.numAlleles());
		
		// new way of computing constants we need to multiply and divide (old "multFactor" cancels out)
		double weight = (1-Math.exp(-(tPoint.endPoint-tPoint.startPoint)*this.nbar[tIdx]/this.sizes[tIdx]));
		double nConst = this.nbar[tIdx]/this.sizes[tIdx] * Math.exp(tPoint.startPoint*this.nbar[tIdx]/this.sizes[tIdx]);
		
		// loop over each diagonal entry of the diagonal matrix
		for (int a=0; a < this.pSet.numAlleles(); a++) {
			
			double currM = 1;
			double currQ = getRVal(tIdx, 0, cConst);			
			double result = currQ;
			
			// perform the "infinite" sum
			for (int m=1; m < NUM_SUM; m++) {
				currQ = getRVal(tIdx, m, cConst) + m / cConst * currQ;
				currM *= this.pSet.getMutationRate() / m * this.pSet.getMutationDiagonal().get(a,a);
				
				// only add on if a real number or not infinity
				if (!Float.isNaN((float)(currM * currQ)) && !Float.isInfinite((float)(currM * currQ))) {
					result += currM * currQ;
				}
			}
			resultMatrix.set(a,a,result*nConst/weight);
		}
		
		// multiply diagonal by U on left and Uinv on right
		resultMatrix = this.pSet.getMutationU().times(resultMatrix); 
		resultMatrix = resultMatrix.times(this.pSet.getMutationUinverse());
		
		return resultMatrix;
	}
	
	// this is where the terms of v are added on (Eq. A10), i is the time index
	private double getRVal(int tIdx, long k, double cConst) {
		DiscretizationInterval tPoint = this.tPoints[tIdx];
		double start = Math.exp(-cConst * tPoint.startPoint) * Math.pow(tPoint.startPoint, k);
		double end = Double.isInfinite(tPoint.endPoint) ? 0 : Math.exp(-cConst * tPoint.endPoint) * Math.pow(tPoint.endPoint, k);
		return 1 / cConst * (start - end);
	}
}
